Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better cast/convert support for Cosmos #35000

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/EFCore.Cosmos/Query/Internal/CosmosQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,20 @@ protected override Expression VisitObjectBinary(ObjectBinaryExpression objectBin
/// </summary>
protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpression)
{
if (sqlUnaryExpression.OperatorType == ExpressionType.Convert)
{
if (sqlUnaryExpression.TypeMapping?.ClrType == typeof(string))
{
_sqlBuilder.Append("ToString(");
Visit(sqlUnaryExpression.Operand);
_sqlBuilder.Append(")");
}
else
{
Visit(sqlUnaryExpression.Operand);
}
return sqlUnaryExpression;
}
var op = sqlUnaryExpression.OperatorType switch
{
ExpressionType.UnaryPlus => "+",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,12 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
/// </summary>
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
//Strip convert to object. Other converts should be fine as they will have a type mapping found but object won't

if (item is UnaryExpression { NodeType:ExpressionType.Convert} unaryExpression && unaryExpression.Type == typeof(object))
{
item = unaryExpression.Operand;
}
// Simplify x.Array.Contains[1] => ARRAY_CONTAINS(x.Array, 1) insert of IN+subquery
if (source.TryExtractArray(out var array, ignoreOrderings: true)
&& array is SqlExpression scalarArray // TODO: Contains over arrays of structural types, #34027
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ protected virtual void AddTranslationErrorDetails(string details)

if (result is SqlExpression translation)
{
if (translation is SqlUnaryExpression { OperatorType: ExpressionType.Convert } sqlUnaryExpression
&& sqlUnaryExpression.Type == typeof(object))
{
translation = sqlUnaryExpression.Operand;
}

if (applyDefaultTypeMapping)
{
translation = sqlExpressionFactory.ApplyDefaultTypeMapping(translation);
Expand Down Expand Up @@ -185,10 +191,6 @@ when TryRewriteEntityEquality(
equalsMethod: false,
out var result):
return result;

case { Method: var method } when method == ConcatMethodInfo:
return QueryCompilationContext.NotTranslatedExpression;

default:
var uncheckedNodeTypeVariant = binaryExpression.NodeType switch
{
Expand Down Expand Up @@ -625,7 +627,7 @@ when method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains):
arguments = new SqlExpression[methodCallExpression.Arguments.Count];
for (var i = 0; i < arguments.Length; i++)
{
var argument = methodCallExpression.Arguments[i];
var argument = RemoveObjectConvert(methodCallExpression.Arguments[i]);
if (TranslationFailed(argument, Visit(argument), out var sqlArgument))
{
return TranslateAsSubquery(methodCallExpression);
Expand Down Expand Up @@ -813,32 +815,46 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
return QueryCompilationContext.NotTranslatedExpression;
}

return unaryExpression.NodeType switch
switch (unaryExpression.NodeType)
{
ExpressionType.Not
=> sqlExpressionFactory.Not(sqlOperand!),

ExpressionType.Negate or ExpressionType.NegateChecked
=> sqlExpressionFactory.Negate(sqlOperand!),

// Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is
// passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ
// artifact that shouldn't affect translation, but the latter may be an indication from the user that they want to apply a
// type change.
ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory<float> (for vector search)
|| (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" }
&& IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType()))
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
case ExpressionType.Not:
return sqlExpressionFactory.Not(sqlOperand!);

case ExpressionType.Negate:
case ExpressionType.NegateChecked:
return sqlExpressionFactory.Negate(sqlOperand!);

case ExpressionType.Convert:
case ExpressionType.ConvertChecked:
case ExpressionType.TypeAs:
// Object convert needs to be converted to explicit cast when mismatching types
// But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object)
=> sqlOperand!,
if (operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory<float> (for vector search)
|| (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" }
&& IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType()))
|| unaryExpression.Type.UnwrapNullableType() == operand.Type.UnwrapNullableType()
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum))
{
return sqlOperand!;
}

_ => QueryCompilationContext.NotTranslatedExpression
};
// Introduce explicit cast only if the target type is mapped else we need to client eval
if (unaryExpression.Type == typeof(object)
|| typeMappingSource.FindMapping(unaryExpression.Type, queryCompilationContext.Model) != null)
{
sqlOperand = sqlExpressionFactory.ApplyDefaultTypeMapping(sqlOperand);

return sqlExpressionFactory.Convert(sqlOperand!, unaryExpression.Type);
}

break;

case ExpressionType.Quote:
return operand;
}

return QueryCompilationContext.NotTranslatedExpression;

static bool IsReadOnlyMemory(Type type)
=> type is { IsGenericType: true, IsGenericTypeDefinition: false }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ public class SqlUnaryExpression : SqlExpression
{
private static readonly ISet<ExpressionType> AllowedOperators = new HashSet<ExpressionType>
{
ExpressionType.Equal,
ExpressionType.NotEqual,
ExpressionType.Convert,
ExpressionType.Not,
ExpressionType.Negate,
ExpressionType.OnesComplement,
ExpressionType.UnaryPlus
};

Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
case ExpressionType.Coalesce:
{
inferredTypeMapping = typeMapping ?? ExpressionExtensions.InferTypeMapping(left, right);
resultType = inferredTypeMapping?.ClrType ?? left.Type;
resultType = inferredTypeMapping?.ClrType ?? (left.Type != typeof(object) ? left.Type : right.Type);
resultTypeMapping = inferredTypeMapping;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,22 +438,33 @@ FROM root c
""");
});

public override async Task Sum_with_division_on_decimal(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(async () => await base.Sum_with_division_on_decimal(async));
public override Task Sum_with_division_on_decimal(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Sum_with_division_on_decimal(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE SUM((c["Quantity"] / 2.09))
FROM root c
WHERE (c["$type"] = "OrderDetail")
""");
});

public override async Task Sum_with_division_on_decimal_no_significant_digits(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(
async () => await base.Sum_with_division_on_decimal_no_significant_digits(async));
public override Task Sum_with_division_on_decimal_no_significant_digits(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Sum_with_division_on_decimal_no_significant_digits(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE SUM((c["Quantity"] / 2.0))
FROM root c
WHERE (c["$type"] = "OrderDetail")
""");
});

public override Task Sum_with_coalesce(bool async)
=> Fixture.NoSyncTest(
Expand Down Expand Up @@ -723,22 +734,33 @@ FROM root c
""");
});

public override async Task Average_with_division_on_decimal(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(async () => await base.Average_with_division_on_decimal(async));
public override Task Average_with_division_on_decimal(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Average_with_division_on_decimal(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE AVG((c["Quantity"] / 2.09))
FROM root c
WHERE (c["$type"] = "OrderDetail")
""");
});

public override async Task Average_with_division_on_decimal_no_significant_digits(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(
async () => await base.Average_with_division_on_decimal_no_significant_digits(async));
public override Task Average_with_division_on_decimal_no_significant_digits(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Average_with_division_on_decimal_no_significant_digits(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE AVG((c["Quantity"] / 2.0))
FROM root c
WHERE (c["$type"] = "OrderDetail")
""");
});

public override Task Average_with_coalesce(bool async)
=> Fixture.NoSyncTest(
Expand Down Expand Up @@ -2036,36 +2058,47 @@ public override async Task OfType_Select_OfType_Select(bool async)
AssertSql();
}

public override async Task Average_with_non_matching_types_in_projection_doesnt_produce_second_explicit_cast(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(
async () => await base.Average_with_non_matching_types_in_projection_doesnt_produce_second_explicit_cast(async));
public override Task Average_with_non_matching_types_in_projection_doesnt_produce_second_explicit_cast(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Average_with_non_matching_types_in_projection_doesnt_produce_second_explicit_cast(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE AVG(c["OrderID"])
FROM root c
WHERE ((c["$type"] = "Order") AND STARTSWITH(c["CustomerID"], "A"))
""");
});

public override async Task Max_with_non_matching_types_in_projection_introduces_explicit_cast(bool async)
{
// Always throws for sync.
if (async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(
async () => await base.Max_with_non_matching_types_in_projection_introduces_explicit_cast(async));
public override Task Max_with_non_matching_types_in_projection_introduces_explicit_cast(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Max_with_non_matching_types_in_projection_introduces_explicit_cast(a);

AssertSql();
}
}
AssertSql(
"""
SELECT VALUE MAX(c["OrderID"])
FROM root c
WHERE ((c["$type"] = "Order") AND STARTSWITH(c["CustomerID"], "A"))
""");
});

public override async Task Min_with_non_matching_types_in_projection_introduces_explicit_cast(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(
async () => await base.Min_with_non_matching_types_in_projection_introduces_explicit_cast(async));
public override Task Min_with_non_matching_types_in_projection_introduces_explicit_cast(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Min_with_non_matching_types_in_projection_introduces_explicit_cast(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE MIN(c["OrderID"])
FROM root c
WHERE ((c["$type"] = "Order") AND STARTSWITH(c["CustomerID"], "A"))
""");
});

public override async Task OrderBy_Take_Last_gives_correct_result(bool async)
{
Expand Down Expand Up @@ -2488,13 +2521,19 @@ public override async Task Collection_LastOrDefault_member_access_in_projection_
AssertSql();
}

public override async Task Sum_over_explicit_cast_over_column(bool async)
{
// Aggregate selecting non-mapped type. Issue #20677.
await Assert.ThrowsAsync<KeyNotFoundException>(async () => await base.Sum_over_explicit_cast_over_column(async));
public override Task Sum_over_explicit_cast_over_column(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Sum_over_explicit_cast_over_column(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE SUM(c["OrderID"])
FROM root c
WHERE (c["$type"] = "Order")
""");
});

public override async Task Contains_over_scalar_with_null_should_rewrite_to_identity_equality_subquery(bool async)
{
Expand Down Expand Up @@ -2778,13 +2817,19 @@ OFFSET 0 LIMIT 1
""");
});

[ConditionalTheory(Skip = "Issue #20677")]
public override async Task Type_casting_inside_sum(bool async)
{
await base.Type_casting_inside_sum(async);
public override Task Type_casting_inside_sum(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Type_casting_inside_sum(a);

AssertSql();
}
AssertSql(
"""
SELECT VALUE SUM(c["Discount"])
FROM root c
WHERE (c["$type"] = "OrderDetail")
""");
});

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
Expand Down
Loading