Skip to content

Commit 7421b99

Browse files
committed
CSHARP-5749: Support C# 14 changes that result in overloads now binding MemoryExtensions extension methods
1 parent c979cd5 commit 7421b99

File tree

14 files changed

+677
-18
lines changed

14 files changed

+677
-18
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private AstStage RenderProjectStage(
6363
ExpressionTranslationOptions translationOptions,
6464
out IBsonSerializer<TOutput> outputSerializer)
6565
{
66-
var partiallyEvaluatedOutput = (Expression<Func<TGrouping, TOutput>>)PartialEvaluator.EvaluatePartially(_output);
66+
var partiallyEvaluatedOutput = (Expression<Func<TGrouping, TOutput>>)LinqExpressionPreprocessor.Preprocess(_output);
6767
var context = TranslationContext.Create(translationOptions);
6868
var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true);
6969
var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation);
@@ -105,7 +105,7 @@ protected override AstStage RenderGroupingStage(
105105
ExpressionTranslationOptions translationOptions,
106106
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
107107
{
108-
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
108+
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
109109
var context = TranslationContext.Create(translationOptions);
110110
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
111111

@@ -149,7 +149,7 @@ protected override AstStage RenderGroupingStage(
149149
ExpressionTranslationOptions translationOptions,
150150
out IBsonSerializer<IGrouping<AggregateBucketAutoResultId<TValue>, TInput>> groupingOutputSerializer)
151151
{
152-
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
152+
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
153153
var context = TranslationContext.Create(translationOptions);
154154
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
155155

@@ -187,7 +187,7 @@ protected override AstStage RenderGroupingStage(
187187
ExpressionTranslationOptions translationOptions,
188188
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
189189
{
190-
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
190+
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
191191
var context = TranslationContext.Create(translationOptions);
192192
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
193193
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System;
17+
using System.Collections.ObjectModel;
18+
using System.Linq;
19+
using System.Linq.Expressions;
20+
using System.Reflection;
21+
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
22+
23+
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc;
24+
25+
/// <summary>
26+
/// This visitor rewrites expressions where new features of .NET CLR or
27+
/// C# compiler interfere with LINQ expression tree translation.
28+
/// </summary>
29+
internal class ClrCompatExpressionRewriter : ExpressionVisitor
30+
{
31+
private static readonly ClrCompatExpressionRewriter __instance = new();
32+
33+
public static Expression Rewrite(Expression expression)
34+
=> __instance.Visit(expression);
35+
36+
/// <inheritdoc />
37+
protected override Expression VisitMethodCall(MethodCallExpression node)
38+
{
39+
node = (MethodCallExpression)base.VisitMethodCall(node);
40+
41+
var method = node.Method;
42+
var arguments = node.Arguments;
43+
44+
return method.Name switch
45+
{
46+
"Contains" => VisitContainsMethod(node, method, arguments),
47+
"SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments),
48+
_ => node
49+
};
50+
51+
static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
52+
{
53+
if (method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue))
54+
{
55+
var itemType = method.GetGenericArguments().Single();
56+
var span = arguments[0];
57+
var value = arguments[1];
58+
59+
if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
60+
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType))
61+
{
62+
return
63+
Expression.Call(
64+
EnumerableMethod.Contains.MakeGenericMethod(itemType),
65+
[unwrappedSpan, value]);
66+
}
67+
}
68+
else if (method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer))
69+
{
70+
var itemType = method.GetGenericArguments().Single();
71+
var span = arguments[0];
72+
var value = arguments[1];
73+
var comparer = arguments[2];
74+
75+
if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
76+
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType))
77+
{
78+
return
79+
Expression.Call(
80+
EnumerableMethod.ContainsWithComparer.MakeGenericMethod(itemType),
81+
[unwrappedSpan, value, comparer]);
82+
}
83+
}
84+
85+
return node;
86+
}
87+
88+
static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
89+
{
90+
if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan))
91+
{
92+
var itemType = method.GetGenericArguments().Single();
93+
var span = arguments[0];
94+
var other = arguments[1];
95+
96+
if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
97+
TryUnwrapSpanImplicitCast(other, out var unwrappedOther) &&
98+
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) &&
99+
unwrappedOther.Type.ImplementsIEnumerableOf(itemType))
100+
{
101+
return
102+
Expression.Call(
103+
EnumerableMethod.SequenceEqual.MakeGenericMethod(itemType),
104+
[unwrappedSpan, unwrappedOther]);
105+
}
106+
}
107+
else if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer))
108+
{
109+
var itemType = method.GetGenericArguments().Single();
110+
var span = arguments[0];
111+
var other = arguments[1];
112+
var comparer = arguments[2];
113+
114+
if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
115+
TryUnwrapSpanImplicitCast(other, out var unwrappedOther) &&
116+
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) &&
117+
unwrappedOther.Type.ImplementsIEnumerableOf(itemType))
118+
{
119+
return
120+
Expression.Call(
121+
EnumerableMethod.SequenceEqualWithComparer.MakeGenericMethod(itemType),
122+
[unwrappedSpan, unwrappedOther, comparer]);
123+
}
124+
}
125+
126+
return node;
127+
}
128+
129+
// Erase implicit casts to ReadOnlySpan<T> and Span<T>
130+
static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result)
131+
{
132+
if (expression is MethodCallExpression
133+
{
134+
Method:
135+
{
136+
Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType
137+
}
138+
} methodCallExpression
139+
&& implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition
140+
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
141+
{
142+
result = methodCallExpression.Arguments[0];
143+
return true;
144+
}
145+
146+
result = null;
147+
return false;
148+
}
149+
}
150+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq.Expressions;
17+
18+
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc;
19+
20+
/// <summary>
21+
/// This class is called before we process any LINQ expression trees
22+
/// to perform any necessary pre-processing such as CLR compatibility
23+
/// and partial evaluation.
24+
/// </summary>
25+
internal static class LinqExpressionPreprocessor
26+
{
27+
public static Expression Preprocess(Expression expression)
28+
{
29+
expression = ClrCompatExpressionRewriter.Rewrite(expression);
30+
expression = PartialEvaluator.EvaluatePartially(expression);
31+
return expression;
32+
}
33+
}

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,59 @@
1313
* limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Reflection;
1718

1819
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
1920
{
2021
internal static class MethodInfoExtensions
2122
{
23+
public static bool Has1GenericArgument(this MethodInfo method, out Type genericArgument)
24+
{
25+
if (method.IsGenericMethod &&
26+
method.GetGenericArguments() is var genericArguments &&
27+
genericArguments.Length == 1)
28+
{
29+
genericArgument = genericArguments[0];
30+
return true;
31+
}
32+
33+
genericArgument = null;
34+
return false;
35+
}
36+
37+
public static bool Has2Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2)
38+
{
39+
if (method.GetParameters() is var parameters &&
40+
parameters.Length == 2)
41+
{
42+
parameter1 = parameters[0];
43+
parameter2 = parameters[1];
44+
return true;
45+
}
46+
47+
parameter1 = null;
48+
parameter2 = null;
49+
return false;
50+
}
51+
52+
public static bool Has3Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2, out ParameterInfo parameter3)
53+
{
54+
if (method.GetParameters() is var parameters &&
55+
parameters.Length == 3)
56+
{
57+
parameter1 = parameters[0];
58+
parameter2 = parameters[1];
59+
parameter3 = parameters[2];
60+
return true;
61+
}
62+
63+
parameter1 = null;
64+
parameter2 = null;
65+
parameter3 = null;
66+
return false;
67+
}
68+
2269
public static bool Is(this MethodInfo method, MethodInfo comparand)
2370
{
2471
if (comparand != null)

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,27 @@ public static bool IsNullableOf(this Type type, Type valueType)
248248
return type.IsNullable(out var nullableValueType) && nullableValueType == valueType;
249249
}
250250

251+
public static bool IsReadOnlySpanOf(this Type type, Type itemType)
252+
{
253+
return
254+
type.IsGenericType &&
255+
type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) &&
256+
type.GetGenericArguments()[0] == itemType;
257+
}
258+
251259
public static bool IsSameAsOrNullableOf(this Type type, Type valueType)
252260
{
253261
return type == valueType || type.IsNullableOf(valueType);
254262
}
255263

264+
public static bool IsSpanOf(this Type type, Type itemType)
265+
{
266+
return
267+
type.IsGenericType &&
268+
type.GetGenericTypeDefinition() == typeof(Span<>) &&
269+
type.GetGenericArguments()[0] == itemType;
270+
}
271+
256272
public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInterface)
257273
{
258274
return

src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ internal static class EnumerableMethod
5959
private static readonly MethodInfo __cast;
6060
private static readonly MethodInfo __concat;
6161
private static readonly MethodInfo __contains;
62+
private static readonly MethodInfo __containsWithComparer;
6263
private static readonly MethodInfo __count;
6364
private static readonly MethodInfo __countWithPredicate;
6465
private static readonly MethodInfo __defaultIfEmpty;
@@ -150,6 +151,7 @@ internal static class EnumerableMethod
150151
private static readonly MethodInfo __selectManyWithSelectorTakingIndex;
151152
private static readonly MethodInfo __selectWithSelectorTakingIndex;
152153
private static readonly MethodInfo __sequenceEqual;
154+
private static readonly MethodInfo __sequenceEqualWithComparer;
153155
private static readonly MethodInfo __single;
154156
private static readonly MethodInfo __singleOrDefault;
155157
private static readonly MethodInfo __singleOrDefaultWithPredicate;
@@ -226,6 +228,7 @@ static EnumerableMethod()
226228
__cast = ReflectionInfo.Method((IEnumerable source) => source.Cast<object>());
227229
__concat = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second) => first.Concat(second));
228230
__contains = ReflectionInfo.Method((IEnumerable<object> source, object value) => source.Contains(value));
231+
__containsWithComparer = ReflectionInfo.Method((IEnumerable<object> source, object value, IEqualityComparer<object> comparer) => source.Contains(value, comparer));
229232
__count = ReflectionInfo.Method((IEnumerable<object> source) => source.Count());
230233
__countWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Count(predicate));
231234
__defaultIfEmpty = ReflectionInfo.Method((IEnumerable<object> source) => source.DefaultIfEmpty());
@@ -317,6 +320,7 @@ static EnumerableMethod()
317320
__selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable<object> source, Func<object, int, IEnumerable<object>> selector) => source.SelectMany(selector));
318321
__selectWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable<object> source, Func<object, int, object> selector) => source.Select(selector));
319322
__sequenceEqual = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second) => first.SequenceEqual(second));
323+
__sequenceEqualWithComparer = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second, IEqualityComparer<object> comparer) => first.SequenceEqual(second, comparer));
320324
__single = ReflectionInfo.Method((IEnumerable<object> source) => source.Single());
321325
__singleOrDefault = ReflectionInfo.Method((IEnumerable<object> source) => source.SingleOrDefault());
322326
__singleOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.SingleOrDefault(predicate));
@@ -392,6 +396,7 @@ static EnumerableMethod()
392396
public static MethodInfo Cast => __cast;
393397
public static MethodInfo Concat => __concat;
394398
public static MethodInfo Contains => __contains;
399+
public static MethodInfo ContainsWithComparer => __containsWithComparer;
395400
public static MethodInfo Count => __count;
396401
public static MethodInfo CountWithPredicate => __countWithPredicate;
397402
public static MethodInfo DefaultIfEmpty => __defaultIfEmpty;
@@ -483,6 +488,7 @@ static EnumerableMethod()
483488
public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
484489
public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
485490
public static MethodInfo SequenceEqual => __sequenceEqual;
491+
public static MethodInfo SequenceEqualWithComparer => __sequenceEqualWithComparer;
486492
public static MethodInfo Single => __single;
487493
public static MethodInfo SingleOrDefault => __singleOrDefault;
488494
public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;

0 commit comments

Comments
 (0)