Skip to content

Refactor XmlDocProvider #7904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading;
using System.Threading.Tasks;

namespace Sample
{
/// <summary> TestClient description. </summary>
public partial class TestClient
{
/// <summary>
/// [Protocol Method] Operation description
/// <list type="bullet">
/// <item>
/// <description> This <see href="https://aka.ms/azsdk/net/protocol-methods">protocol method</see> allows explicit creation of the request and processing of the response for advanced scenarios. </description>
/// </item>
/// </list>
/// </summary>
/// <param name="queryParam_modified"> queryParam description. </param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="global::System.ArgumentNullException"> <paramref name="queryParam_modified"/> is null. </exception>
/// <exception cref="global::System.ArgumentException"> <paramref name="queryParam_modified"/> is an empty string, and was expected to be non-empty. </exception>
/// <exception cref="global::System.ClientModel.ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual global::System.ClientModel.ClientResult Operation(string queryParam_modified, global::System.ClientModel.Primitives.RequestOptions options)
{
global::Sample.Argument.AssertNotNullOrEmpty(queryParam_modified, nameof(queryParam_modified));

using global::System.ClientModel.Primitives.PipelineMessage message = this.CreateOperationRequest(queryParam_modified, options);
return global::System.ClientModel.ClientResult.FromResponse(Pipeline.ProcessMessage(message, options));
}

/// <summary>
/// [Protocol Method] Operation description
/// <list type="bullet">
/// <item>
/// <description> This <see href="https://aka.ms/azsdk/net/protocol-methods">protocol method</see> allows explicit creation of the request and processing of the response for advanced scenarios. </description>
/// </item>
/// </list>
/// </summary>
/// <param name="queryParam_modified"> queryParam description. </param>
/// <param name="options"> The request options, which can override default behaviors of the client pipeline on a per-call basis. </param>
/// <exception cref="global::System.ArgumentNullException"> <paramref name="queryParam_modified"/> is null. </exception>
/// <exception cref="global::System.ArgumentException"> <paramref name="queryParam_modified"/> is an empty string, and was expected to be non-empty. </exception>
/// <exception cref="global::System.ClientModel.ClientResultException"> Service returned a non-success status code. </exception>
/// <returns> The response returned from the service. </returns>
public virtual async global::System.Threading.Tasks.Task<global::System.ClientModel.ClientResult> OperationAsync(string queryParam_modified, global::System.ClientModel.Primitives.RequestOptions options)
{
global::Sample.Argument.AssertNotNullOrEmpty(queryParam_modified, nameof(queryParam_modified));

using global::System.ClientModel.Primitives.PipelineMessage message = this.CreateOperationRequest(queryParam_modified, options);
return global::System.ClientModel.ClientResult.FromResponse(await Pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false));
}

/// <summary> Operation description. </summary>
/// <param name="cancellationToken"> The cancellation token that can be used to cancel the operation. </param>
/// <exception cref="global::System.ClientModel.ClientResultException"> Service returned a non-success status code. </exception>
public virtual global::System.ClientModel.ClientResult Operation(global::System.Threading.CancellationToken cancellationToken = default)
{
return this.Operation(cancellationToken.CanBeCanceled ? new global::System.ClientModel.Primitives.RequestOptions { CancellationToken = cancellationToken } : null);
}

/// <summary> Operation description. </summary>
/// <param name="cancellationToken"> The cancellation token that can be used to cancel the operation. </param>
/// <exception cref="global::System.ClientModel.ClientResultException"> Service returned a non-success status code. </exception>
public virtual async global::System.Threading.Tasks.Task<global::System.ClientModel.ClientResult> OperationAsync(global::System.Threading.CancellationToken cancellationToken = default)
{
return await this.OperationAsync(cancellationToken.CanBeCanceled ? new global::System.ClientModel.Primitives.RequestOptions { CancellationToken = cancellationToken } : null).ConfigureAwait(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Linq;
using Microsoft.TypeSpec.Generator.ClientModel;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.ClientModel.Tests;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;

namespace Microsoft.TypeSpec.Generator.Tests.Providers
{
public class XmlDocProviderTests
{
private const string TestClientName = "TestClient";
private static readonly InputClient _testClient = InputFactory.Client(TestClientName,
methods: [
InputFactory.BasicServiceMethod(
"Operation",
InputFactory.Operation(
"Operation",
parameters:
[
InputFactory.Parameter(
"queryParam",
InputPrimitiveType.String,
isRequired: true,
location: InputRequestLocation.Query)
]))]);

[Test]
public void ValidateXmlDocShouldChangeFromVisitors()
{
MockHelpers.LoadMockGenerator(
createClientCore: inputClient => new MockClientProvider(inputClient),
clients: () => [_testClient],
includeXmlDocs: true
);
var testVisitor = new TestVisitor();
ScmCodeModelGenerator.Instance.AddVisitor(testVisitor);

// visit the library
testVisitor.DoVisitLibrary(CodeModelGenerator.Instance.OutputLibrary);

// check if the parameter names in xml docs are changed accordingly
// find the client in outputlibrary
var client = ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders.OfType<ClientProvider>().FirstOrDefault()!;
Assert.IsNotNull(client);
var writer = ScmCodeModelGenerator.Instance.GetWriter(client);
var file = writer.Write();

Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

private class TestVisitor : ScmLibraryVisitor
{
public void DoVisitLibrary(OutputLibrary library)
{
VisitLibrary(library);
}

protected internal override ScmMethodProvider? VisitMethod(ScmMethodProvider method)
{
// modify the parameter names in-place
foreach (var parameter in method.Signature.Parameters)
{
if (parameter.Name == "queryParam")
{
// modify the parameter name
parameter.Update(name: "queryParam_modified");
}
}
return method;
}
}

private class MockClientProvider : ClientProvider
{
public MockClientProvider(InputClient client) : base(client)
{ }

// ignore all the ctors to make the output more clear
protected override ConstructorProvider[] BuildConstructors() => [];

// ignore all the fields to make the output more clear
protected override FieldProvider[] BuildFields() => [];

protected override PropertyProvider[] BuildProperties() => [];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Statements;
Expand All @@ -24,6 +26,7 @@ namespace Microsoft.TypeSpec.Generator.Primitives
/// <param name="GenericParameterConstraints">The generic parameter constraints of the method.</param>
/// <param name="ExplicitInterface">The explicit interface of the method.</param>
/// <param name="NonDocumentComment">The non-document comment of the method.</param>
[DebuggerDisplay("{GetDebuggerDisplay(),nq}")]
public sealed class MethodSignature(string Name, FormattableString? Description, MethodSignatureModifiers Modifiers, CSharpType? ReturnType, FormattableString? ReturnDescription, IReadOnlyList<ParameterProvider> Parameters, IReadOnlyList<AttributeStatement>? Attributes = null, IReadOnlyList<CSharpType>? GenericArguments = null, IReadOnlyList<WhereExpression>? GenericParameterConstraints = null, CSharpType? ExplicitInterface = null, string? NonDocumentComment = null)
: MethodSignatureBase(Name, Description, NonDocumentComment, Modifiers, Parameters, Attributes ?? Array.Empty<AttributeStatement>(), ReturnType)
{
Expand Down Expand Up @@ -85,5 +88,10 @@ public int GetHashCode([DisallowNull] MethodSignature obj)
return HashCode.Combine(obj.Name, obj.ReturnType);
}
}

private string GetDebuggerDisplay()
{
return $"{ReturnType?.FullyQualifiedName ?? "void"} {Name}({string.Join(", ", Parameters.Select(p => $"{p.Type.FullyQualifiedName} {p.Name}"))})";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,9 @@ protected override MethodProvider[] BuildMethods()
$"A new {modelProvider.Type:C} instance for mocking.",
GetParameters(modelProvider, fullConstructor));

var parameters = new List<XmlDocParamStatement>(signature.Parameters.Count);
foreach (var param in signature.Parameters)
{
parameters.Add(new XmlDocParamStatement(param));
}

var docs = new XmlDocProvider(
modelProvider.XmlDocs.Summary,
parameters,
signature.Parameters,
returns: new XmlDocReturnsStatement($"A new {modelProvider.Type:C} instance for mocking."));

MethodBodyStatement statements = ConstructMethodBody(signature, typeToInstantiate);
Expand Down Expand Up @@ -271,7 +265,7 @@ private static bool TryBuildMethodArgumentsForOverload(
return false;
}

var currentParameters = currentMethod.Parameters.ToHashSet();
var currentParameters = currentMethod.Parameters.ToHashSet(ParameterProvider.EqualityByNameAndType);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since in this context, we are trying to compare the parameters by their names and types, we used a comparer here to do that, and its default behavior in a dictionary is by reference equality.

foreach (var parameter in previousMethod.Parameters)
{
if (!currentParameters.Contains(parameter))
Expand All @@ -281,7 +275,7 @@ private static bool TryBuildMethodArgumentsForOverload(
}

// Build the arguments for the overload
var previousParameters = previousMethod.Parameters.ToHashSet();
var previousParameters = previousMethod.Parameters.ToHashSet(ParameterProvider.EqualityByNameAndType);
List<ValueExpression> arguments = new(currentMethodParameterCount);

foreach (var parameter in currentMethod.Parameters)
Expand Down Expand Up @@ -470,7 +464,7 @@ private static bool ContainsSameParameters(MethodSignature method1, MethodSignat
return false;
}

HashSet<ParameterProvider> method1Parameters = [.. method1.Parameters];
HashSet<ParameterProvider> method1Parameters = method1.Parameters.ToHashSet(ParameterProvider.EqualityByNameAndType);
foreach (var method2Param in method2.Parameters)
{
if (!method1Parameters.Contains(method2Param))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using Microsoft.TypeSpec.Generator.Expressions;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Input.Extensions;
Expand Down Expand Up @@ -158,13 +159,7 @@ public bool Equals(ParameterProvider? y)

public override int GetHashCode()
{
return GetHashCode(this);
}

private int GetHashCode([DisallowNull] ParameterProvider obj)
{
// remove type as part of the hash code generation as the type might have changes between versions
return HashCode.Combine(obj.Name);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i change this because the Name property in this class is not immutable - therefore if we use this type as key of dictionary or in hashset in a scenario that its name would change, weird behavior would happen.
For instance, if we do this:

var set = new HashSet<ParameterProvider>();
set.Add(parameter);
parameter.Update(name: "bar"); // change the name to bar
Console.WriteLine(set.Contains(parameter); // this will return false.

return RuntimeHelpers.GetHashCode(this); // gets the hash code based on object reference
}

private string GetDebuggerDisplay()
Expand Down Expand Up @@ -333,5 +328,27 @@ public void Update(
Validation = validation.Value;
}
}

internal static IEqualityComparer<ParameterProvider> EqualityByNameAndType = new ParameterProviderEqualityComparer();

private struct ParameterProviderEqualityComparer : IEqualityComparer<ParameterProvider>
{
public bool Equals(ParameterProvider? x, ParameterProvider? y)
{
if (ReferenceEquals(x, y))
{
return true;
}
if (x is null || y is null)
{
return false;
}
return x.Equals(y);
}
public int GetHashCode([DisallowNull] ParameterProvider obj)
{
return HashCode.Combine(obj.Name);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,43 @@
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using Microsoft.TypeSpec.Generator.Statements;

namespace Microsoft.TypeSpec.Generator.Providers
{
public class XmlDocProvider
{
public static XmlDocProvider Empty => new XmlDocProvider();
public static XmlDocProvider Empty { get; } = new XmlDocProvider();
public static XmlDocProvider InheritDocs { get; } = new XmlDocProvider { Inherit = new XmlDocInheritStatement() };

private IReadOnlyList<ParameterProvider>? _parameters;

public XmlDocProvider(
XmlDocSummaryStatement? summary = null,
IReadOnlyList<XmlDocParamStatement>? parameters = null,
IReadOnlyList<ParameterProvider>? parameters = null,
IReadOnlyList<XmlDocExceptionStatement>? exceptions = null,
XmlDocReturnsStatement? returns = null,
XmlDocInheritStatement? inherit = null)
{
Summary = summary;
Parameters = parameters ?? new List<XmlDocParamStatement>();
_parameters = parameters;
Exceptions = exceptions ?? new List<XmlDocExceptionStatement>();
Returns = returns;
Inherit = inherit;
}

private static XmlDocProvider? _inheritDocs;

public static XmlDocProvider InheritDocs =>
_inheritDocs ??= new XmlDocProvider { Inherit = new XmlDocInheritStatement() };

public XmlDocSummaryStatement? Summary { get; private set; }
public IReadOnlyList<XmlDocParamStatement> Parameters { get; private set; }

private IReadOnlyList<XmlDocParamStatement>? _parameterStatements;
public IReadOnlyList<XmlDocParamStatement> Parameters => _parameterStatements ??= _parameters?.Select(p => new XmlDocParamStatement(p)).ToArray() ?? [];
public XmlDocReturnsStatement? Returns { get; private set; }
public IReadOnlyList<XmlDocExceptionStatement> Exceptions { get; private set; }
public XmlDocInheritStatement? Inherit { get; private set; }

public void Update(
XmlDocSummaryStatement? summary = null,
IReadOnlyList<XmlDocParamStatement>? parameters = null,
IReadOnlyList<ParameterProvider>? parameters = null,
IReadOnlyList<XmlDocExceptionStatement>? exceptions = null,
XmlDocReturnsStatement? returns = null,
XmlDocInheritStatement? inherit = null)
Expand All @@ -49,7 +50,8 @@ public void Update(

if (parameters != null)
{
Parameters = parameters;
_parameters = parameters;
_parameterStatements = null; // Reset the cached parameter statements
}

if (exceptions != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ public class XmlDocExceptionStatement : MethodBodyStatement
public Type ExceptionType { get; }
public IReadOnlyList<ParameterProvider> Parameters { get; }

private readonly string _reason;
private string? _reason;

public XmlDocExceptionStatement(Type exceptionType, IReadOnlyList<ParameterProvider> parameters)
{
ExceptionType = exceptionType;
Parameters = parameters;
_reason = GetText(exceptionType);
}

public XmlDocExceptionStatement(Type exceptionType, string reason, IReadOnlyList<ParameterProvider> parameters)
Expand Down Expand Up @@ -54,7 +53,7 @@ internal override void Write(CodeWriter writer)
writer.Append($" or <paramref name=\"{Parameters[Parameters.Count - 1].AsExpression().Declaration}\"/>");
}

writer.WriteLine($" {_reason} </exception>");
writer.WriteLine($" {_reason ?? GetText(ExceptionType)} </exception>");
}
}
}
Loading
Loading