From eefa3c9095c0e662cd637e7e838a2ed1c9131cb0 Mon Sep 17 00:00:00 2001 From: SwimmingRieux <141845753+SwimmingRieux@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:34:04 +0330 Subject: [PATCH] feat(Sanitization middleware): add sanetization middleware (#20) * add sanetization middleware * complete sanetizing middleware --- .../Middlewares/SanitizationMiddleware.cs | 82 +++++++++++++++++++ RelationshipAnalysis/Program.cs | 3 +- .../RelationshipAnalysis.csproj | 1 + 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 RelationshipAnalysis/Middlewares/SanitizationMiddleware.cs diff --git a/RelationshipAnalysis/Middlewares/SanitizationMiddleware.cs b/RelationshipAnalysis/Middlewares/SanitizationMiddleware.cs new file mode 100644 index 0000000..36a2a33 --- /dev/null +++ b/RelationshipAnalysis/Middlewares/SanitizationMiddleware.cs @@ -0,0 +1,82 @@ +using System.Text; +using Ganss.Xss; +using Microsoft.AspNetCore.Mvc.Controllers; +using Newtonsoft.Json; + +namespace RelationshipAnalysis.Middlewares; +public class SanitizationMiddleware +{ + private readonly RequestDelegate _next; + private readonly HtmlSanitizer _sanitizer; + + public SanitizationMiddleware(RequestDelegate next) + { + _next = next; + _sanitizer = new HtmlSanitizer(); + } + + public async Task InvokeAsync(HttpContext context) + { + if (context.Request.ContentType != null && context.Request.ContentType.Contains("application/json")) + { + context.Request.EnableBuffering(); + var body = await new StreamReader(context.Request.Body).ReadToEndAsync(); + context.Request.Body.Position = 0; + + var type = GetRequestDtoType(context); + if (type != null) + { + object sanitizedDto; + if (type == typeof(List)) + { + var dto = JsonConvert.DeserializeObject>(body); + sanitizedDto = SanitizeEnumerable(dto); + } + else + { + var dto = JsonConvert.DeserializeObject(body, type); + sanitizedDto = SanitizeDto(dto); + } + + var sanitizedBody = JsonConvert.SerializeObject(sanitizedDto); + var buffer = Encoding.UTF8.GetBytes(sanitizedBody); + context.Request.Body = new MemoryStream(buffer); + } + } + + await _next(context); + } + + private Type GetRequestDtoType(HttpContext context) + { + var endpoint = context.GetEndpoint(); + var actionDescriptor = endpoint?.Metadata.GetMetadata(); + if (actionDescriptor != null) + { + var parameters = actionDescriptor.Parameters; + var dtoParameter = parameters.FirstOrDefault(p => p.ParameterType.IsClass && p.ParameterType != typeof(string)); + return dtoParameter?.ParameterType; + } + return null; + } + + private IEnumerable SanitizeEnumerable(IEnumerable dto) + { + return dto.Select(str => _sanitizer.Sanitize(str)); + } + private object SanitizeDto(object dto) + { + var properties = dto.GetType().GetProperties().Where(p => p.PropertyType == typeof(string) && p.CanWrite && p.CanRead); + + foreach (var property in properties) + { + var value = (string)property.GetValue(dto); + if (value != null) + { + property.SetValue(dto, _sanitizer.Sanitize(value)); + } + } + + return dto; + } +} diff --git a/RelationshipAnalysis/Program.cs b/RelationshipAnalysis/Program.cs index afa8fec..2795e69 100644 --- a/RelationshipAnalysis/Program.cs +++ b/RelationshipAnalysis/Program.cs @@ -4,6 +4,7 @@ using DotNetEnv; using Microsoft.EntityFrameworkCore; using RelationshipAnalysis.Context; +using RelationshipAnalysis.Middlewares; using RelationshipAnalysis.Services; using RelationshipAnalysis.Services.AccessServices; using RelationshipAnalysis.Services.AccessServices.Abstraction; @@ -91,7 +92,7 @@ app.MapControllers(); app.UseCors(x => x.AllowCredentials().AllowAnyHeader().AllowAnyMethod() .SetIsOriginAllowed(x => true)); - +app.UseMiddleware(); app.Run(); public partial class Program diff --git a/RelationshipAnalysis/RelationshipAnalysis.csproj b/RelationshipAnalysis/RelationshipAnalysis.csproj index 5395809..7fe42ca 100644 --- a/RelationshipAnalysis/RelationshipAnalysis.csproj +++ b/RelationshipAnalysis/RelationshipAnalysis.csproj @@ -9,6 +9,7 @@ +