diff --git a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Helper.cs b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Helper.cs index c1730a4ec..508cec065 100644 --- a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Helper.cs +++ b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Helper.cs @@ -598,7 +598,7 @@ public static List CreateImmutable ColumnBit = false, ColumnDateTime = EpocDate.AddDays(index), ColumnDateTime2 = DateTime.UtcNow.AddDays(index), - ColumnDecimal = index+1, + ColumnDecimal = index + 1, ColumnFloat = index + 1, ColumnInt = index + 1, ColumnNVarChar = $"NVARCHAR{index}-Property" @@ -901,6 +901,42 @@ public static List CreateDottedTables(int count) #endregion + #region NonKeyedTable + + /// + /// Creates a list of objects. + /// + /// The number of rows. + /// A list of objects. + public static IEnumerable CreateNonKeyedTables(int count = 10) + { + for (var index = 0; index < count; index++) + { + yield return new NonKeyedTable + { + ColumnDateTime2 = DateTime.UtcNow, + ColumnInt = index, + ColumnNVarChar = $"NVARCHAR{index}" + }; + } + } + + /// + /// Creates an instance of object. + /// + /// A new created instance of object. + public static NonKeyedTable CreateNonKeyedTable() + { + return new NonKeyedTable + { + ColumnDateTime2 = DateTime.UtcNow, + ColumnInt = new Random().Next(int.MinValue, int.MaxValue), + ColumnNVarChar = Guid.NewGuid().ToString() + }; + } + + #endregion + #region Dynamics #region IdentityTable @@ -983,53 +1019,6 @@ public static Tuple, IEnumerable> CreateDynamicIdentityTabl return new Tuple, IEnumerable>(tables, fields); } - /// - /// Creates a list of dynamic objects for [sc].[IdentityTable] without an identity. - /// - /// The number of rows. - /// A list of dynamic objects. - public static List CreateDynamicIdentityTablesWithoutIdentity(int count) - { - var tables = new List(); - for (var i = 0; i < count; i++) - { - var index = i + 1; - tables.Add(new - { - RowGuid = Guid.NewGuid(), - ColumnBit = true, - ColumnDateTime = EpocDate.AddDays(index), - ColumnDateTime2 = DateTime.UtcNow, - ColumnDecimal = Convert.ToDecimal(index), - ColumnFloat = Convert.ToDouble(index), - ColumnInt = index, - ColumnNVarChar = $"NVARCHAR{index}" - }); - } - return tables; - } - - /// - /// Creates a list of dynamic objects for [dbo].[NonKeyedTable] table. - /// - /// The number of rows. - /// A list of dynamic objects. - public static List CreateDynamicNonKeyedTables(int count) - { - var tables = new List(); - for (var i = 0; i < count; i++) - { - var index = i + 1; - tables.Add(new - { - ColumnDateTime2 = DateTime.UtcNow, - ColumnInt = index, - ColumnNVarChar = $"NVARCHAR{index}" - }); - } - return tables; - } - #endregion #region NonIdentityTable @@ -1111,6 +1100,45 @@ public static Tuple, IEnumerable> CreateDynamicNonIdentityT #endregion + #region NonKeyedTable + + /// + /// Creates a list of dynamic objects for [dbo].[NonKeyedTable] table. + /// + /// The number of rows. + /// A list of dynamic objects. + public static List CreateDynamicNonKeyedTables(int count) + { + var tables = new List(); + for (var i = 0; i < count; i++) + { + var index = i + 1; + tables.Add(new + { + ColumnDateTime2 = DateTime.UtcNow, + ColumnInt = index, + ColumnNVarChar = $"NVARCHAR{index}" + }); + } + return tables; + } + + /// + /// Creates an instance of dynamic object for [dbo].[NonKeyedTable] table. + /// + /// An instance of dynamic object. + public static dynamic CreateDynamicNonKeyedTable() + { + return new + { + ColumnDateTime2 = DateTime.UtcNow, + ColumnInt = new Random().Next(int.MinValue, int.MaxValue), + ColumnNVarChar = Guid.NewGuid().ToString() + }; + } + + #endregion + #endregion } } diff --git a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/MergeAllTest.cs b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/MergeAllTest.cs index 7ef2b4272..804bf7d15 100644 --- a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/MergeAllTest.cs +++ b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/MergeAllTest.cs @@ -1514,16 +1514,16 @@ public void TestSqlConnectionMergeAllViaTableNameForNonIdentityEmptyTableWithHin } } - [TestMethod, ExpectedException(typeof(InvalidQualifiersException))] + [TestMethod, ExpectedException(typeof(KeyFieldNotFoundException))] public void ThrowExceptionOnSqlConnectionMergeAllIfTheKeyFieldIsNotPresent() { // Setup - var tables = Helper.CreateDynamicIdentityTablesWithoutIdentity(10); + var tables = Helper.CreateDynamicNonKeyedTables(10); using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) { // Act - connection.MergeAll(ClassMappedNameCache.Get(), tables); + connection.MergeAll(ClassMappedNameCache.Get(), tables); } } @@ -2002,12 +2002,12 @@ public void TestSqlConnectionMergeAllAsyncViaTableNameForNonIdentityEmptyTableWi public void ThrowExceptionOnSqlConnectionMergeAllAsyncIfTheKeyFieldIsNotPresent() { // Setup - var tables = Helper.CreateDynamicIdentityTablesWithoutIdentity(10); + var tables = Helper.CreateNonKeyedTables(10); using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) { // Act - connection.MergeAllAsync(ClassMappedNameCache.Get(), tables).Wait(); + connection.MergeAllAsync(ClassMappedNameCache.Get(), tables).Wait(); } } diff --git a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/QueryTest.cs b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/QueryTest.cs index f018eea99..fccd211f9 100644 --- a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/QueryTest.cs +++ b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/QueryTest.cs @@ -390,6 +390,26 @@ public void TestSqlConnectionQueryViaExpression() } } + [TestMethod] + public void TestSqlConnectionQueryViaExpressionWitNullValue() + { + // Setup + var tables = Helper.CreateIdentityTables(10); + var last = tables.Last(); + + using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) + { + // Act + connection.InsertAll(tables); + + // Act + var result = connection.Query(c => c.ColumnNVarChar == null); + + // Assert + Assert.AreEqual(0, result.Count()); + } + } + [TestMethod] public void TestSqlConnectionQueryViaQueryField() { @@ -1545,6 +1565,26 @@ public void TestSqlConnectionQueryAsyncViaExpression() } } + [TestMethod] + public void TestSqlConnectionQueryAsyncViaExpressionWithNullValue() + { + // Setup + var tables = Helper.CreateIdentityTables(10); + var last = tables.Last(); + + using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) + { + // Act + connection.InsertAll(tables); + + // Act + var result = connection.QueryAsync(c => c.ColumnNVarChar == null).Result; + + // Assert + Assert.AreEqual(0, result.Count()); + } + } + [TestMethod] public void TestSqlConnectionQueryAsyncViaQueryField() { diff --git a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/UpdateTest.cs b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/UpdateTest.cs index 978dde0b5..031426141 100644 --- a/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/UpdateTest.cs +++ b/RepoDb.Core/RepoDb.Tests/RepoDb.IntegrationTests/Operations/UpdateTest.cs @@ -1651,30 +1651,45 @@ public void TestSqlConnectionUpdateViaTableNameWithHints() } } + [TestMethod, ExpectedException(typeof(KeyFieldNotFoundException))] + public void ThrowExceptionOnSqlConnectionUpdateIfThereIsNoKeyField() + { + // Setup + var data = Helper.CreateNonKeyedTable(); + + using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) + { + // Act + connection.Update(data); + } + } + [TestMethod, ExpectedException(typeof(KeyFieldNotFoundException))] public void ThrowExceptionOnSqlConnectionUpdateViaTableNameIfThereIsNoKeyField() { + // Setup + var data = Helper.CreateDynamicNonKeyedTable(); + using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) { - var data = new - { - ColumnInt = 1, - ColumnDecimal = 2 - }; - connection.Update(ClassMappedNameCache.Get(), data); + // Act + connection.Update(ClassMappedNameCache.Get(), (object)data); } } [TestMethod, ExpectedException(typeof(EmptyException))] public void ThrowExceptionOnSqlConnectionUpdateViaTableNameIfTheFieldsAreNotFound() { + // Setup + var data = new + { + Id = 1, + AnyField = 1 + }; + using (var connection = new SqlConnection(Database.ConnectionStringForRepoDb)) { - var data = new - { - Id = 1, - AnyField = 1 - }; + // Act connection.Update(ClassMappedNameCache.Get(), data); } } diff --git a/RepoDb.Core/RepoDb/Extensions/DbCommandExtension.cs b/RepoDb.Core/RepoDb/Extensions/DbCommandExtension.cs index 02038e0dc..65166804a 100644 --- a/RepoDb.Core/RepoDb/Extensions/DbCommandExtension.cs +++ b/RepoDb.Core/RepoDb/Extensions/DbCommandExtension.cs @@ -237,10 +237,10 @@ private static void CreateParametersInternal(IDbCommand command, var dbField = GetDbField(name, dbFields); var value = classProperty.PropertyInfo.GetValue(param); var returnType = (Type)null; - var dbType = (DbType?)null; // Propertyhandler - var definition = InvokePropertyHandlerSetMethod(classProperty.GetPropertyHandler(), value, classProperty); + var propertyHandler = GetProperyHandler(classProperty, value?.GetType()); + var definition = InvokePropertyHandlerSetMethod(propertyHandler, value, classProperty); if (definition != null) { returnType = definition.ReturnType.GetUnderlyingType(); @@ -256,15 +256,9 @@ private static void CreateParametersInternal(IDbCommand command, } // DbType - if (returnType != null) - { - dbType = clientTypeToDbTypeResolver.Resolve(returnType); - } - else - { - dbType = classProperty.GetDbType() ?? - value?.GetType()?.GetDbType(); - } + var dbType = (returnType != null ? clientTypeToDbTypeResolver.Resolve(returnType) : null) ?? + classProperty.GetDbType() ?? + value?.GetType()?.GetDbType(); // Add the parameter command.Parameters.Add(command.CreateParameter(name, value, dbType)); @@ -308,7 +302,7 @@ private static void CreateParameters(IDbCommand command, } // Propertyhandler - var propertyHandler = classProperty?.GetPropertyHandler() ?? PropertyHandlerCache.Get(valueType); + var propertyHandler = GetProperyHandler(classProperty, valueType); var definition = InvokePropertyHandlerSetMethod(propertyHandler, value, classProperty); if (definition != null) { @@ -325,7 +319,7 @@ private static void CreateParameters(IDbCommand command, } // DbType - var dbType = clientTypeToDbTypeResolver.Resolve(valueType) ?? + var dbType = (valueType != null ? clientTypeToDbTypeResolver.Resolve(valueType) : null) ?? classProperty?.GetDbType() ?? value?.GetType()?.GetDbType(); @@ -430,7 +424,7 @@ private static void CreateParameters(this IDbCommand command, // PropertyHandler var classProperty = PropertyCache.Get(entityType, queryField.Field); - var propertyHandler = classProperty?.GetPropertyHandler() ?? PropertyHandlerCache.Get(valueType); + var propertyHandler = GetProperyHandler(classProperty, valueType); var definition = InvokePropertyHandlerSetMethod(propertyHandler, value, classProperty); if (definition != null) { @@ -447,7 +441,7 @@ private static void CreateParameters(this IDbCommand command, } // DbType - var dbType = clientTypeToDbTypeResolver.Resolve(valueType) ?? + var dbType = (valueType != null ? clientTypeToDbTypeResolver.Resolve(valueType) : null) ?? classProperty?.GetDbType() ?? value?.GetType()?.GetDbType(); @@ -481,7 +475,7 @@ private static void CreateParametersForInOperation(this IDbCommand command, var valueType = value?.GetType()?.GetUnderlyingType(); // Propertyhandler - var properyHandler = PropertyHandlerCache.Get(valueType); + var properyHandler = GetProperyHandler(null, valueType); var definition = InvokePropertyHandlerSetMethod(properyHandler, value, null); if (definition != null) { @@ -498,7 +492,7 @@ private static void CreateParametersForInOperation(this IDbCommand command, } // DbType - var dbType = clientTypeToDbTypeResolver.Resolve(valueType); + var dbType = (valueType != null ? clientTypeToDbTypeResolver.Resolve(valueType) : null); // Create command.Parameters.Add(CreateParameter(command, name, values[i], dbType)); @@ -528,10 +522,10 @@ private static void CreateParametersForBetweenOperation(this IDbCommand command, var leftValue = values[0]; var rightValue = values[1]; var leftValueType = leftValue?.GetType()?.GetUnderlyingType(); - var rightValueType = leftValue?.GetType()?.GetUnderlyingType(); + var rightValueType = rightValue?.GetType()?.GetUnderlyingType(); // Propertyhandler (Left) - var leftPropertyHandler = PropertyHandlerCache.Get(leftValueType); + var leftPropertyHandler = GetProperyHandler(null, leftValueType); var leftdefinition = InvokePropertyHandlerSetMethod(leftPropertyHandler, leftValue, null); if (leftdefinition != null) { @@ -540,8 +534,8 @@ private static void CreateParametersForBetweenOperation(this IDbCommand command, } // Propertyhandler (Right) - var rightPropertyHandler = PropertyHandlerCache.Get(leftValueType); - var rightDefinition = InvokePropertyHandlerSetMethod(rightPropertyHandler, leftValue, null); + var rightPropertyHandler = GetProperyHandler(null, rightValueType); + var rightDefinition = InvokePropertyHandlerSetMethod(rightPropertyHandler, rightValue, null); if (rightDefinition != null) { rightValueType = rightDefinition.ReturnType.GetUnderlyingType(); @@ -558,8 +552,8 @@ private static void CreateParametersForBetweenOperation(this IDbCommand command, } // DbType - var leftDbType = clientTypeToDbTypeResolver.Resolve(leftValueType); - var rightDbType = clientTypeToDbTypeResolver.Resolve(rightValueType); + var leftDbType = (leftValueType != null ? clientTypeToDbTypeResolver.Resolve(leftValueType) : null); + var rightDbType = (rightValueType != null ? clientTypeToDbTypeResolver.Resolve(rightValueType) : null); // Add command.Parameters.Add( @@ -695,6 +689,23 @@ private static object AutomaticConvertGuidToString(object value) return value?.ToString(); } + /// + /// + /// + /// + /// + /// + private static object GetProperyHandler(ClassProperty classProperty, + Type targetType) + { + var propertyHandler = classProperty?.GetPropertyHandler(); + if (propertyHandler == null && targetType != null) + { + propertyHandler = PropertyHandlerCache.Get(targetType); + } + return propertyHandler; + } + #endregion } } diff --git a/RepoDb.Core/RepoDb/Extensions/DbConnectionExtension.cs b/RepoDb.Core/RepoDb/Extensions/DbConnectionExtension.cs index 4558623f9..60265afa7 100644 --- a/RepoDb.Core/RepoDb/Extensions/DbConnectionExtension.cs +++ b/RepoDb.Core/RepoDb/Extensions/DbConnectionExtension.cs @@ -201,12 +201,12 @@ internal static IEnumerable ExecuteQueryInternal(this IDbConnection con { using (var reader = command.ExecuteReader()) { - var result = (IEnumerable)DataReader.ToEnumerable(reader, dbFields, connection.GetDbSetting()).AsList(); + var result = DataReader.ToEnumerable(reader, dbFields, connection.GetDbSetting()).AsList(); // Set Cache if (cacheKey != null) { - cache?.Add(cacheKey, result, cacheItemExpiration.GetValueOrDefault(), false); + cache?.Add(cacheKey, (IEnumerable)result, cacheItemExpiration.GetValueOrDefault(), false); } // Return @@ -324,7 +324,7 @@ internal static async Task> ExecuteQueryAsyncInternal(this { using (var reader = await command.ExecuteReaderAsync(cancellationToken)) { - var result = (await DataReader.ToEnumerableAsync(reader, dbFields, connection.GetDbSetting(), cancellationToken)).AsList(); + var result = await DataReader.ToEnumerableAsync(reader, dbFields, connection.GetDbSetting()); // Set Cache if (cacheKey != null) @@ -576,13 +576,12 @@ private static IEnumerable ExecuteQueryInternalForType(this ID { using (var reader = command.ExecuteReader()) { - var result = (IEnumerable)DataReader.ToEnumerable(reader, - dbFields, connection.GetDbSetting()).AsList(); + var result = DataReader.ToEnumerable(reader, dbFields, connection.GetDbSetting()).AsList(); // Set Cache if (cacheKey != null) { - cache?.Add(cacheKey, result, cacheItemExpiration.GetValueOrDefault(), false); + cache?.Add(cacheKey, (IEnumerable)result, cacheItemExpiration.GetValueOrDefault(), false); } // Return @@ -842,7 +841,8 @@ private static async Task> ExecuteQueryAsyncInternalForType { using (var reader = await command.ExecuteReaderAsync(cancellationToken)) { - var result = (await DataReader.ToEnumerableAsync(reader, dbFields, connection.GetDbSetting(), cancellationToken)).AsList(); + var result = await DataReader.ToEnumerableAsync(reader, dbFields, + connection.GetDbSetting()); // Set Cache if (cacheKey != null) @@ -1692,13 +1692,27 @@ internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(IDbCon internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(IDbConnection connection, string tableName, IDbTransaction transaction) - where TEntity : class + where TEntity : class => + GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, typeof(TEntity)); + + /// + /// + /// + /// + /// + /// + /// + /// + internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(IDbConnection connection, + string tableName, + IDbTransaction transaction, + Type entityType) { - var property = PrimaryCache.Get() ?? IdentityCache.Get(); + var dbFields = DbFieldCache.Get(connection, tableName, transaction); + var property = GetAndGuardPrimaryKeyOrIdentityKey(entityType, dbFields); if (property == null) { - var dbFields = DbFieldCache.Get(connection, tableName, transaction); - property = GetAndGuardPrimaryKeyOrIdentityKey(dbFields); + property = GetPrimaryOrIdentityKey(entityType); } return GetAndGuardPrimaryKeyOrIdentityKey(tableName, property); } @@ -1709,13 +1723,13 @@ internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(IDbCon /// /// /// - /// + /// /// internal static Task GetAndGuardPrimaryKeyOrIdentityKeyAsync(IDbConnection connection, IDbTransaction transaction, - CancellationToken cancellation = default) + CancellationToken cancellationToken = default) where TEntity : class => - GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, ClassMappedNameCache.Get(), transaction, cancellation); + GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, ClassMappedNameCache.Get(), transaction, cancellationToken); /// /// @@ -1726,17 +1740,33 @@ internal static Task GetAndGuardPrimaryKeyOrIdentityKeyAsync /// /// - internal static async Task GetAndGuardPrimaryKeyOrIdentityKeyAsync(IDbConnection connection, + internal static Task GetAndGuardPrimaryKeyOrIdentityKeyAsync(IDbConnection connection, string tableName, IDbTransaction transaction, CancellationToken cancellationToken = default) - where TEntity : class + where TEntity : class => + GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, typeof(TEntity), cancellationToken); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + internal static async Task GetAndGuardPrimaryKeyOrIdentityKeyAsync(IDbConnection connection, + string tableName, + IDbTransaction transaction, + Type entityType, + CancellationToken cancellationToken = default) { - var property = PrimaryCache.Get() ?? IdentityCache.Get(); + var dbFields = await DbFieldCache.GetAsync(connection, tableName, transaction, cancellationToken); + var property = GetAndGuardPrimaryKeyOrIdentityKey(entityType, dbFields); if (property == null) { - var dbFields = await DbFieldCache.GetAsync(connection, tableName, transaction, cancellationToken); - property = GetAndGuardPrimaryKeyOrIdentityKey(dbFields); + property = GetPrimaryOrIdentityKey(entityType); } return GetAndGuardPrimaryKeyOrIdentityKey(tableName, property); } @@ -1760,21 +1790,48 @@ internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(string tableNam /// /// /// - /// + /// /// /// - internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(IEnumerable dbFields) - where TEntity : class + internal static ClassProperty GetAndGuardPrimaryKeyOrIdentityKey(Type entityType, + IEnumerable dbFields) { - var dbField = dbFields?.FirstOrDefault(df => df.IsPrimary == true) ?? - dbFields?.FirstOrDefault(df => df.IsIdentity == true); - if (dbField != null) + if (entityType == null) + { + return null; + } + + // Properties + var properties = PropertyCache.Get(entityType) ?? entityType?.GetClassProperties(); + var key = (ClassProperty)null; + + // Primary + if (key == null) + { + var dbField = dbFields?.FirstOrDefault(df => df.IsPrimary == true); + key = properties?.FirstOrDefault(p => + string.Equals(p.GetMappedName(), dbField?.Name, StringComparison.OrdinalIgnoreCase)) ?? + PrimaryCache.Get(entityType); + } + + // Identity + if (key == null) + { + var dbField = dbFields?.FirstOrDefault(df => df.IsIdentity == true); + key = properties?.FirstOrDefault(p => + string.Equals(p.GetMappedName(), dbField?.Name, StringComparison.OrdinalIgnoreCase)) ?? + PrimaryCache.Get(entityType); + } + + // Return + if (key != null) + { + return key; + } + else { - var properties = PropertyCache.Get() ?? typeof(TEntity).GetClassProperties(); - return properties.FirstOrDefault(p => - string.Equals(p.GetMappedName(), dbField.Name, StringComparison.OrdinalIgnoreCase)); + throw new KeyFieldNotFoundException($"No primary key and identify found at type '{entityType.FullName}'."); } - throw new KeyFieldNotFoundException($"No primary key and identify found at type '{typeof(TEntity).Name}'."); } /// @@ -1852,8 +1909,17 @@ internal static QueryGroup WhatToQueryGroup(this IDbConnection connection, var queryGroup = WhatToQueryGroup(what); if (queryGroup == null) { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); - queryGroup = WhatToQueryGroup(key, what); + var whatType = what?.GetType(); + if (whatType.IsClassType() || whatType.IsAnonymousType()) + { + var classProperty = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, whatType); + queryGroup = WhatToQueryGroup(classProperty, what); + } + else + { + var dbField = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + queryGroup = WhatToQueryGroup(dbField, what); + } } return queryGroup; } @@ -1881,8 +1947,17 @@ internal static async Task WhatToQueryGroupAsync(this IDbConnecti var queryGroup = WhatToQueryGroup(what); if (queryGroup == null) { - var dbField = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); - queryGroup = WhatToQueryGroup(dbField, what); + var whatType = what?.GetType(); + if (whatType.IsClassType() || whatType.IsAnonymousType()) + { + var classProperty = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, whatType, cancellationToken); + queryGroup = WhatToQueryGroup(classProperty, what); + } + else + { + var dbField = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + queryGroup = WhatToQueryGroup(dbField, what); + } } return queryGroup; } @@ -2171,10 +2246,18 @@ internal static QueryGroup ToQueryGroup(IEnumerable queryFields) #endregion /// - /// Throws an exception if the entities argument is null or empty. + /// /// - /// The type of the result. - /// The enumerable list of entity objects. + /// + /// + internal static ClassProperty GetPrimaryOrIdentityKey(Type entityType) => + entityType != null ? (PrimaryCache.Get(entityType) ?? IdentityCache.Get(entityType)) : null; + + /// + /// + /// + /// + /// internal static void ThrowIfNullOrEmpty(IEnumerable entities) where TEntity : class { diff --git a/RepoDb.Core/RepoDb/Operations/DbConnection/Merge.cs b/RepoDb.Core/RepoDb/Operations/DbConnection/Merge.cs index f1d32574c..79a334011 100644 --- a/RepoDb.Core/RepoDb/Operations/DbConnection/Merge.cs +++ b/RepoDb.Core/RepoDb/Operations/DbConnection/Merge.cs @@ -1915,7 +1915,8 @@ internal static TResult MergeInternalBase(this IDbConnection c // Check the qualifiers if (qualifiers?.Any() != true) { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, + entity?.GetType() ?? typeof(TEntity)); qualifiers = key.AsField().AsEnumerable(); } @@ -2177,7 +2178,8 @@ internal static async Task MergeAsyncInternalBase(thi // Check the qualifiers if (qualifiers?.Any() != true) { - var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, + entity?.GetType() ?? typeof(TEntity), cancellationToken); qualifiers = key.AsField().AsEnumerable(); } diff --git a/RepoDb.Core/RepoDb/Operations/DbConnection/MergeAll.cs b/RepoDb.Core/RepoDb/Operations/DbConnection/MergeAll.cs index 0d025bede..776b55b65 100644 --- a/RepoDb.Core/RepoDb/Operations/DbConnection/MergeAll.cs +++ b/RepoDb.Core/RepoDb/Operations/DbConnection/MergeAll.cs @@ -377,7 +377,8 @@ internal static int MergeAllInternal(this IDbConnection connection, // Check the qualifiers if (qualifiers?.Any() != true) { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, + entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity)); qualifiers = key.AsField().AsEnumerable(); } @@ -800,7 +801,8 @@ internal static async Task MergeAllAsyncInternal(this IDbConnectio // Check the qualifiers if (qualifiers?.Any() != true) { - var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, + entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity), cancellationToken); qualifiers = key.AsField().AsEnumerable(); } @@ -1359,7 +1361,7 @@ internal static int UpsertAllInternalBase(this IDbConnection connection where TEntity : class { // Variables needed - var type = entities?.First()?.GetType() ?? typeof(TEntity); + var type = entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity); var isObjectType = typeof(TEntity) == StaticType.Object; var dbFields = DbFieldCache.Get(connection, tableName, transaction); var primary = dbFields?.FirstOrDefault(dbField => dbField.IsPrimary); @@ -1760,7 +1762,7 @@ internal static async Task UpsertAllAsyncInternalBase(this IDbConn where TEntity : class { // Variables needed - var type = entities?.First()?.GetType() ?? typeof(TEntity); + var type = entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity); var isObjectType = typeof(TEntity) == StaticType.Object; var dbFields = await DbFieldCache.GetAsync(connection, tableName, transaction, cancellationToken); var primary = dbFields?.FirstOrDefault(dbField => dbField.IsPrimary); diff --git a/RepoDb.Core/RepoDb/Operations/DbConnection/Update.cs b/RepoDb.Core/RepoDb/Operations/DbConnection/Update.cs index c9d4ee67d..ad99f4769 100644 --- a/RepoDb.Core/RepoDb/Operations/DbConnection/Update.cs +++ b/RepoDb.Core/RepoDb/Operations/DbConnection/Update.cs @@ -45,7 +45,8 @@ public static int Update(this IDbConnection connection, IStatementBuilder statementBuilder = null) where TEntity : class { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, + entity?.GetType() ?? typeof(TEntity)); return UpdateInternal(connection: connection, tableName: tableName, entity: entity, @@ -622,7 +623,8 @@ public static async Task UpdateAsync(this IDbConnection connection CancellationToken cancellationToken = default) where TEntity : class { - var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, + entity?.GetType() ?? typeof(TEntity), cancellationToken); return await UpdateAsyncInternal(connection: connection, tableName: tableName, entity: entity, @@ -1238,7 +1240,7 @@ public static int Update(this IDbConnection connection, ITrace trace = null, IStatementBuilder statementBuilder = null) { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, entity?.GetType()); return UpdateInternal(connection: connection, tableName: tableName, entity: entity, @@ -1428,7 +1430,8 @@ public static async Task UpdateAsync(this IDbConnection connection, IStatementBuilder statementBuilder = null, CancellationToken cancellationToken = default) { - var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, + entity?.GetType(), cancellationToken); return await UpdateAsyncInternal(connection: connection, tableName: tableName, entity: entity, diff --git a/RepoDb.Core/RepoDb/Operations/DbConnection/UpdateAll.cs b/RepoDb.Core/RepoDb/Operations/DbConnection/UpdateAll.cs index fdcba341f..fdafbf4f4 100644 --- a/RepoDb.Core/RepoDb/Operations/DbConnection/UpdateAll.cs +++ b/RepoDb.Core/RepoDb/Operations/DbConnection/UpdateAll.cs @@ -376,7 +376,8 @@ internal static int UpdateAllInternal(this IDbConnection connection, { if (qualifiers?.Any() != true) { - var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction); + var key = GetAndGuardPrimaryKeyOrIdentityKey(connection, tableName, transaction, + entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity)); qualifiers = key.AsField().AsEnumerable(); } return UpdateAllInternalBase(connection: connection, @@ -777,7 +778,8 @@ internal static async Task UpdateAllAsyncInternal(this IDbConnecti { if (qualifiers?.Any() != true) { - var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, cancellationToken); + var key = await GetAndGuardPrimaryKeyOrIdentityKeyAsync(connection, tableName, transaction, + entities?.FirstOrDefault()?.GetType() ?? typeof(TEntity), cancellationToken); qualifiers = key.AsField().AsEnumerable(); } return await UpdateAllAsyncInternalBase(connection: connection, diff --git a/RepoDb.Core/RepoDb/RepoDb.csproj b/RepoDb.Core/RepoDb/RepoDb.csproj index 926983ffb..dedd8c635 100644 --- a/RepoDb.Core/RepoDb/RepoDb.csproj +++ b/RepoDb.Core/RepoDb/RepoDb.csproj @@ -3,7 +3,7 @@ netstandard2.0 Michael Camara Pendon - 1.12.1 + 1.12.3 RepoDb RepoDb A hybrid ORM library for .NET. @@ -17,8 +17,8 @@ - 1.12.1.0 - 1.12.1.0 + 1.12.3.0 + 1.12.3.0 True False