diff --git a/Refresh.Common/Extensions/HttpContentExtensions.cs b/Refresh.Common/Extensions/HttpContentExtensions.cs new file mode 100644 index 00000000..355e391d --- /dev/null +++ b/Refresh.Common/Extensions/HttpContentExtensions.cs @@ -0,0 +1,20 @@ +using System.Xml; +using System.Xml.Serialization; +using Newtonsoft.Json; + +namespace Refresh.Common.Extensions; + +public static class HttpContentExtensions +{ + public static T ReadAsXml(this HttpContent content) + { + XmlSerializer serializer = new(typeof(T)); + + return (T)serializer.Deserialize(new XmlTextReader(content.ReadAsStream()))!; + } + + public static T? ReadAsJson(this HttpContent content) + { + return JsonConvert.DeserializeObject(content.ReadAsStringAsync().Result); + } +} \ No newline at end of file diff --git a/Refresh.Common/Extensions/NameValueCollectionExtensions.cs b/Refresh.Common/Extensions/NameValueCollectionExtensions.cs new file mode 100644 index 00000000..d14e5678 --- /dev/null +++ b/Refresh.Common/Extensions/NameValueCollectionExtensions.cs @@ -0,0 +1,35 @@ +using System.Collections.Specialized; +using System.Text; +using System.Web; + +namespace Refresh.Common.Extensions; + +public static class NameValueCollectionExtensions +{ + public static string ToQueryString(this NameValueCollection queryParams) + { + StringBuilder builder = new(); + + if (queryParams.Count == 0) + return string.Empty; + + builder.Append('?'); + for (int i = 0; i < queryParams.Count; i++) + { + string? key = queryParams.GetKey(i); + string? val = queryParams.Get(i); + + if (key == null) + continue; + + builder.Append(HttpUtility.UrlEncode(key)); + builder.Append('='); + if(val != null) + builder.Append(HttpUtility.UrlEncode(val)); + + builder.Append('&'); + } + + return builder.ToString(); + } +} \ No newline at end of file diff --git a/Refresh.Common/Helpers/CryptoHelper.cs b/Refresh.Common/Helpers/CryptoHelper.cs new file mode 100644 index 00000000..2c853c71 --- /dev/null +++ b/Refresh.Common/Helpers/CryptoHelper.cs @@ -0,0 +1,16 @@ +using System.Security.Cryptography; + +namespace Refresh.Common.Helpers; + +public static class CryptoHelper +{ + public static string GetRandomBase64String(int length) + { + byte[] tokenData = new byte[length]; + + using RandomNumberGenerator rng = RandomNumberGenerator.Create(); + rng.GetBytes(tokenData); + + return Convert.ToBase64String(tokenData); + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Configuration/IntegrationConfig.cs b/Refresh.GameServer/Configuration/IntegrationConfig.cs index 4ba760a0..c3b22e44 100644 --- a/Refresh.GameServer/Configuration/IntegrationConfig.cs +++ b/Refresh.GameServer/Configuration/IntegrationConfig.cs @@ -7,7 +7,7 @@ namespace Refresh.GameServer.Configuration; /// public class IntegrationConfig : Config { - public override int CurrentConfigVersion => 6; + public override int CurrentConfigVersion => 9; public override int Version { get; set; } protected override void Migrate(int oldVer, dynamic oldConfig) { @@ -38,6 +38,50 @@ protected override void Migrate(int oldVer, dynamic oldConfig) public string DiscordNickname { get; set; } = "Refresh"; public string DiscordAvatarUrl { get; set; } = "https://raw.githubusercontent.com/LittleBigRefresh/Branding/main/icons/refresh_512x.png"; + #endregion + + #region Discord OAuth + + /// + /// Whether to enable discord OAuth support for account linking + /// + public bool DiscordOAuthEnabled { get; set; } + + /// + /// The redirect URL to use for Discord OAuth requests, ex. `https://lbp.littlebigrefresh.com/api/v3/oauth/authenticate` + /// + public string DiscordOAuthRedirectUrl { get; set; } = "http://localhost:10061/api/v3/oauth/authenticate"; + /// + /// The client ID of the OAuth application + /// + public string DiscordOAuthClientId { get; set; } = ""; + /// + /// The client secret of the OAuth application + /// + public string DiscordOAuthClientSecret { get; set; } = ""; + + #endregion + + #region GitHub OAuth + + /// + /// Whether to enable GitHub OAuth support for account linking + /// + public bool GitHubOAuthEnabled { get; set; } + + /// + /// The redirect URL to use for GitHub OAuth requests, ex. `https://lbp.littlebigrefresh.com/api/v3/oauth/authenticate` + /// + public string GitHubOAuthRedirectUrl { get; set; } = "http://localhost:10061/api/v3/oauth/authenticate"; + /// + /// The client ID of the OAuth application + /// + public string GitHubOAuthClientId { get; set; } = ""; + /// + /// The client secret of the OAuth application + /// + public string GitHubOAuthClientSecret { get; set; } = ""; + #endregion #region AIPI diff --git a/Refresh.GameServer/Database/GameDatabaseContext.OAuth.cs b/Refresh.GameServer/Database/GameDatabaseContext.OAuth.cs new file mode 100644 index 00000000..89c23bb4 --- /dev/null +++ b/Refresh.GameServer/Database/GameDatabaseContext.OAuth.cs @@ -0,0 +1,98 @@ +using Refresh.Common.Helpers; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord; +using Refresh.GameServer.Types.UserData; +using OAuthRequest = Refresh.GameServer.Types.OAuth.OAuthRequest; + +namespace Refresh.GameServer.Database; + +public partial class GameDatabaseContext // oauth +{ + public string CreateOAuthRequest(GameUser user, IDateTimeProvider timeProvider, OAuthProvider provider) + { + string state = CryptoHelper.GetRandomBase64String(128); + + this.Write(() => + { + this.OAuthRequests.Add(new OAuthRequest + { + User = user, + State = state, + ExpiresAt = timeProvider.Now + TimeSpan.FromHours(1), // 1 hour expiry + Provider = provider, + }); + }); + + return state; + } + + /// + /// Returns the OAuthProvider used in a request + /// + /// The OAuth request state + /// The provider, or null if no request was found with that state + public OAuthProvider? OAuthGetProviderForRequest(string state) + => this.OAuthRequests.FirstOrDefault(d => d.State == state)?.Provider; + + public GameUser SaveOAuthToken(string state, OAuth2AccessTokenResponse tokenResponse, IDateTimeProvider timeProvider, OAuthProvider provider) + { + OAuthRequest request = this.OAuthRequests.First(d => d.State == state); + GameUser user = request.User; + + this.Write(() => + { + OAuthTokenRelation? relation = this.OAuthTokenRelations.FirstOrDefault(d => d.User == request.User && d._Provider == (int)provider); + if (relation == null) + { + this.OAuthTokenRelations.Add(relation = new OAuthTokenRelation + { + User = request.User, + Provider = request.Provider, + }); + } + + this.UpdateOAuthToken(relation, tokenResponse, timeProvider); + + this.OAuthRequests.Remove(request); + }); + + return user; + } + + public void UpdateOAuthToken(OAuthTokenRelation token, OAuth2AccessTokenResponse tokenResponse, IDateTimeProvider timeProvider) + { + this.Write(() => + { + token.AccessToken = tokenResponse.AccessToken; + token.RefreshToken = tokenResponse.RefreshToken; + // If we don't have a revocation date, then we assume it never expires, and will just handle a revoked token at request time + token.AccessTokenRevocationTime = tokenResponse.ExpiresIn == null ? DateTimeOffset.MaxValue : timeProvider.Now + TimeSpan.FromSeconds(tokenResponse.ExpiresIn.Value); + }); + } + + public OAuthTokenRelation? GetOAuthTokenFromUser(GameUser user, OAuthProvider provider) + => this.OAuthTokenRelations.FirstOrDefault(d => d.User == user && d._Provider == (int)provider); + + public int RemoveAllExpiredOAuthRequests(IDateTimeProvider timeProvider) + { + IQueryable expired = this.OAuthRequests.Where(d => d.ExpiresAt < timeProvider.Now); + + int removed = expired.Count(); + + this.Write(() => + { + this.OAuthRequests.RemoveRange(expired); + }); + + return removed; + } + + public void RevokeOAuthToken(OAuthTokenRelation token) + { + this.Write(() => + { + this.OAuthTokenRelations.Remove(token); + }); + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Database/GameDatabaseContext.Tokens.cs b/Refresh.GameServer/Database/GameDatabaseContext.Tokens.cs index 6c7885aa..7683ba1a 100644 --- a/Refresh.GameServer/Database/GameDatabaseContext.Tokens.cs +++ b/Refresh.GameServer/Database/GameDatabaseContext.Tokens.cs @@ -1,5 +1,6 @@ using System.Security.Cryptography; using JetBrains.Annotations; +using Refresh.Common.Helpers; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.UserData; @@ -23,16 +24,6 @@ static GameDatabaseContext() GameCookieLength = (int)Math.Floor((MaxGameCookieLength - GameCookieHeader.Length - MaxBase64Padding) * 3 / 4.0); } - private static string GetTokenString(int length) - { - byte[] tokenData = new byte[length]; - - using RandomNumberGenerator rng = RandomNumberGenerator.Create(); - rng.GetBytes(tokenData); - - return Convert.ToBase64String(tokenData); - } - public Token GenerateTokenForUser(GameUser user, TokenType type, TokenGame game, TokenPlatform platform, string ipAddress, int tokenExpirySeconds = DefaultTokenExpirySeconds) { // TODO: JWT (JSON Web Tokens) for TokenType.Api @@ -42,7 +33,7 @@ public Token GenerateTokenForUser(GameUser user, TokenType type, TokenGame game, Token token = new() { User = user, - TokenData = GetTokenString(cookieLength), + TokenData = CryptoHelper.GetRandomBase64String(cookieLength), TokenType = type, TokenGame = game, TokenPlatform = platform, diff --git a/Refresh.GameServer/Database/GameDatabaseContext.Users.cs b/Refresh.GameServer/Database/GameDatabaseContext.Users.cs index 603b5996..ac40ca26 100644 --- a/Refresh.GameServer/Database/GameDatabaseContext.Users.cs +++ b/Refresh.GameServer/Database/GameDatabaseContext.Users.cs @@ -203,6 +203,9 @@ public void UpdateUserData(GameUser user, ApiUpdateUserRequest data) if (data.ShowModdedContent != null) user.ShowModdedContent = data.ShowModdedContent.Value; + + if (data.DiscordProfileVisibility != null) + user.DiscordProfileVisibility = data.DiscordProfileVisibility.Value; }); } diff --git a/Refresh.GameServer/Database/GameDatabaseContext.cs b/Refresh.GameServer/Database/GameDatabaseContext.cs index 17d0b70b..0850dbac 100644 --- a/Refresh.GameServer/Database/GameDatabaseContext.cs +++ b/Refresh.GameServer/Database/GameDatabaseContext.cs @@ -12,12 +12,15 @@ using Refresh.GameServer.Types.Contests; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Notifications; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord; using Refresh.GameServer.Types.Photos; using Refresh.GameServer.Types.Playlists; using Refresh.GameServer.Types.Relations; using Refresh.GameServer.Types.Reviews; using Refresh.GameServer.Types.UserData; using Refresh.GameServer.Types.UserData.Leaderboard; +using OAuthRequest = Refresh.GameServer.Types.OAuth.OAuthRequest; namespace Refresh.GameServer.Database; @@ -61,6 +64,8 @@ public partial class GameDatabaseContext : RealmDatabaseContext private RealmDbSet GamePlaylists => new(this._realm); private RealmDbSet LevelPlaylistRelations => new(this._realm); private RealmDbSet SubPlaylistRelations => new(this._realm); + private RealmDbSet OAuthRequests => new(this._realm); + private RealmDbSet OAuthTokenRelations => new(this._realm); internal GameDatabaseContext(IDateTimeProvider time) { diff --git a/Refresh.GameServer/Database/GameDatabaseProvider.cs b/Refresh.GameServer/Database/GameDatabaseProvider.cs index 0fdce84b..022d6763 100644 --- a/Refresh.GameServer/Database/GameDatabaseProvider.cs +++ b/Refresh.GameServer/Database/GameDatabaseProvider.cs @@ -12,11 +12,14 @@ using Refresh.GameServer.Types.Contests; using Refresh.GameServer.Types.Levels.SkillRewards; using Refresh.GameServer.Types.Notifications; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord; using Refresh.GameServer.Types.Relations; using Refresh.GameServer.Types.Reviews; using Refresh.GameServer.Types.UserData.Leaderboard; using Refresh.GameServer.Types.Photos; using Refresh.GameServer.Types.Playlists; +using OAuthRequest = Refresh.GameServer.Types.OAuth.OAuthRequest; namespace Refresh.GameServer.Database; @@ -34,7 +37,7 @@ protected GameDatabaseProvider(IDateTimeProvider time) this._time = time; } - protected override ulong SchemaVersion => 159; + protected override ulong SchemaVersion => 162; protected override string Filename => "refreshGameServer.realm"; @@ -86,6 +89,10 @@ protected GameDatabaseProvider(IDateTimeProvider time) typeof(GamePlaylist), typeof(LevelPlaylistRelation), typeof(SubPlaylistRelation), + + // oauth + typeof(OAuthRequest), + typeof(OAuthTokenRelation), ]; public override void Warmup() diff --git a/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotFoundError.cs b/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotFoundError.cs index 85cba30a..2c3cf778 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotFoundError.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotFoundError.cs @@ -21,7 +21,13 @@ public class ApiNotFoundError : ApiError public const string ContestMissingErrorWhen = "The contest could not be found"; public static readonly ApiNotFoundError ContestMissingError = new(ContestMissingErrorWhen); + + public const string OAuthTokenMissingErrorWhen = "An OAuth token for this user could not be found"; + public static readonly ApiNotFoundError OAuthTokenMissingError = new(OAuthTokenMissingErrorWhen); + public const string OAuthProviderMissingErrorWhen = "The OAuth provider could not be found"; + public static readonly ApiNotFoundError OAuthProviderMissingError = new(OAuthProviderMissingErrorWhen); + private ApiNotFoundError() : base("The requested resource was not found", NotFound) {} diff --git a/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotSupportedError.cs b/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotSupportedError.cs new file mode 100644 index 00000000..c0307986 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/ApiTypes/Errors/ApiNotSupportedError.cs @@ -0,0 +1,18 @@ +namespace Refresh.GameServer.Endpoints.ApiV3.ApiTypes.Errors; + +public class ApiNotSupportedError : ApiError +{ + public static readonly ApiNotSupportedError Instance = new(); + + public const string OAuthProviderTokenRevocationUnsupportedErrorWhen = "This OAuth provider does not support token revocation"; + public static readonly ApiNotSupportedError OAuthProviderTokenRevocationUnsupportedError = new(OAuthProviderTokenRevocationUnsupportedErrorWhen); + + public const string OAuthProviderDisabledErrorWhen = "The server does not have this OAuth provider enabled"; + public static readonly ApiNotSupportedError OAuthProviderDisabledError = new(OAuthProviderDisabledErrorWhen); + + private ApiNotSupportedError() : base("The server is not configured to support this endpoint.", NotImplemented) + {} + + private ApiNotSupportedError(string message) : base(message, NotImplemented) + {} +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/IApiResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/IApiResponse.cs index 8f8a00c4..82a8c79c 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/IApiResponse.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/IApiResponse.cs @@ -1,6 +1,3 @@ namespace Refresh.GameServer.Endpoints.ApiV3.DataTypes; -public interface IApiResponse -{ - -} \ No newline at end of file +public interface IApiResponse; \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Request/ApiUpdateUserRequest.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Request/ApiUpdateUserRequest.cs index 7ae1e649..a3f9be37 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Request/ApiUpdateUserRequest.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Request/ApiUpdateUserRequest.cs @@ -20,4 +20,5 @@ public class ApiUpdateUserRequest public Visibility? LevelVisibility { get; set; } public Visibility? ProfileVisibility { get; set; } + public Visibility? DiscordProfileVisibility { get; set; } } \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/ApiInstanceResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/ApiInstanceResponse.cs index 0d056530..4d9419cd 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/ApiInstanceResponse.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/ApiInstanceResponse.cs @@ -44,6 +44,7 @@ public class ApiInstanceResponse : IApiResponse public required IEnumerable Announcements { get; set; } public required ApiRichPresenceConfigurationResponse RichPresenceConfiguration { get; set; } + public required bool DiscordOAuthEnabled { get; set; } public required bool MaintenanceModeEnabled { get; set; } public required string? GrafanaDashboardUrl { get; set; } diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/ApiOAuthBeginAuthenticationResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/ApiOAuthBeginAuthenticationResponse.cs new file mode 100644 index 00000000..5c4bfe33 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/ApiOAuthBeginAuthenticationResponse.cs @@ -0,0 +1,6 @@ +namespace Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth; + +public class ApiOAuthBeginAuthenticationResponse : IApiResponse +{ + public required string AuthorizationUrl { get; set; } +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/Discord/ApiDiscordUserResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/Discord/ApiDiscordUserResponse.cs new file mode 100644 index 00000000..38dcb838 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/Discord/ApiDiscordUserResponse.cs @@ -0,0 +1,61 @@ +using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth.Discord.Api; + +namespace Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.Discord; + +[JsonObject(NamingStrategyType = typeof(CamelCaseNamingStrategy))] +public class ApiDiscordUserResponse : IApiResponse, IDataConvertableFrom +{ + /// + /// The user's ID, as a snowflake + /// + public required ulong Id { get; set; } + /// + /// The user's username + /// + public required string Username { get; set; } + /// + /// The user's discord tag + /// + public required string Discriminator { get; set; } + /// + /// The user's global name, if set + /// + public required string? GlobalName { get; set; } + /// + /// The hash of the user's avatar + /// + public required string? AvatarUrl { get; set; } + /// + /// The hash of the user's banner + /// + public required string? BannerUrl { get; set; } + /// + /// The user's accent colour + /// + public required uint? AccentColor { get; set; } + + public static ApiDiscordUserResponse? FromOld(DiscordApiUserResponse? old, DataContext dataContext) + { + if (old == null) + return null; + + return new ApiDiscordUserResponse + { + Id = old.Id, + Username = old.Username, + Discriminator = old.Discriminator, + GlobalName = old.GlobalName, + AvatarUrl = old.Avatar == null + ? "https://cdn.discordapp.com/embed/avatars/0.png" + : $"https://cdn.discordapp.com/avatars/{old.Id}/{old.Avatar}?size=512", + BannerUrl = old.Banner == null + ? null + : $"https://cdn.discordapp.com/banners/{old.Id}/{old.Banner}?size=512", + AccentColor = old.AccentColor, + }; + } + + public static IEnumerable FromOldList(IEnumerable oldList, + DataContext dataContext) => oldList.Select(d => FromOld(d, dataContext)!); +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/GitHub/ApiGitHubUserResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/GitHub/ApiGitHubUserResponse.cs new file mode 100644 index 00000000..7e8370c8 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/OAuth/GitHub/ApiGitHubUserResponse.cs @@ -0,0 +1,30 @@ +using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth.GitHub; + +namespace Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.GitHub; + +[JsonObject(NamingStrategyType = typeof(CamelCaseNamingStrategy))] +public class ApiGitHubUserResponse : IApiResponse, IDataConvertableFrom +{ + public required string? Username { get; set; } + public required string? Name { get; set; } + public required string? ProfileUrl { get; set; } + public required string? AvatarUrl { get; set; } + + public static ApiGitHubUserResponse? FromOld(GitHubApiUserResponse? old, DataContext dataContext) + { + if (old == null) + return null; + + return new ApiGitHubUserResponse + { + Username = old.Login, + Name = old.Name, + ProfileUrl = old.HtmlUrl.ToString(), + AvatarUrl = old.AvatarUrl.ToString(), + }; + } + + public static IEnumerable FromOldList(IEnumerable oldList, DataContext dataContext) + => oldList.Select(d => FromOld(d, dataContext)!); +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs index d18d59ce..f95bdb97 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiExtendedGameUserResponse.cs @@ -1,9 +1,13 @@ using JetBrains.Annotations; using Refresh.GameServer.Authentication; using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.Data; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.Discord; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.GitHub; using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.Users.Rooms; +using Refresh.GameServer.Services.OAuth.Clients; using Refresh.GameServer.Types; using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth; using Refresh.GameServer.Types.Roles; using Refresh.GameServer.Types.UserData; @@ -41,12 +45,16 @@ public class ApiExtendedGameUserResponse : IApiResponse, IDataConvertableFrom null; user:notnull => notnull")] public static ApiExtendedGameUserResponse? FromOld(GameUser? user, DataContext dataContext) @@ -77,8 +85,15 @@ public class ApiExtendedGameUserResponse : IApiResponse, IDataConvertableFrom(OAuthProvider.Discord) + ?.GetUserInformation(dataContext.Database, dataContext.TimeProvider, user), dataContext), + GitHubProfileInfo = ApiGitHubUserResponse.FromOld(dataContext.OAuth + .GetOAuthClient(OAuthProvider.GitHub) + ?.GetUserInformation(dataContext.Database, dataContext.TimeProvider, user), dataContext), }; } diff --git a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiGameUserResponse.cs b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiGameUserResponse.cs index 84f04f0e..24fbc53d 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiGameUserResponse.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/DataTypes/Response/Users/ApiGameUserResponse.cs @@ -1,8 +1,12 @@ using JetBrains.Annotations; -using Refresh.GameServer.Authentication; using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.Data; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.Discord; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.GitHub; using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.Users.Rooms; +using Refresh.GameServer.Services.OAuth.Clients; +using Refresh.GameServer.Types; using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth; using Refresh.GameServer.Types.Roles; using Refresh.GameServer.Types.UserData; @@ -29,12 +33,14 @@ public class ApiGameUserResponse : IApiResponse, IDataConvertableFrom null; notnull => notnull")] + [ContractAnnotation("user:null => null; user:notnull => notnull")] public static ApiGameUserResponse? FromOld(GameUser? user, DataContext dataContext) { if (user == null) return null; - + return new ApiGameUserResponse { UserId = user.UserId.ToString()!, @@ -52,6 +58,21 @@ public class ApiGameUserResponse : IApiResponse, IDataConvertableFrom(OAuthProvider.Discord) + ?.GetUserInformation(dataContext.Database, dataContext.TimeProvider, user), dataContext) + ), + GitHubProfileInfo = user.GitHubProfileVisibility.Filter( + user, + dataContext, + ApiGitHubUserResponse.FromOld(dataContext.OAuth + .GetOAuthClient(OAuthProvider.GitHub) + ?.GetUserInformation(dataContext.Database, dataContext.TimeProvider, user), dataContext) + ), }; } diff --git a/Refresh.GameServer/Endpoints/ApiV3/InstanceApiEndpoints.cs b/Refresh.GameServer/Endpoints/ApiV3/InstanceApiEndpoints.cs index 74a7c55c..8eb7c9bd 100644 --- a/Refresh.GameServer/Endpoints/ApiV3/InstanceApiEndpoints.cs +++ b/Refresh.GameServer/Endpoints/ApiV3/InstanceApiEndpoints.cs @@ -66,6 +66,7 @@ public ApiResponse GetInstanceInformation(RequestContext co gameConfig, richConfig), dataContext)!, GrafanaDashboardUrl = integrationConfig.GrafanaDashboardUrl, + DiscordOAuthEnabled = integrationConfig.DiscordOAuthEnabled, ContactInfo = new ApiContactInfoResponse { diff --git a/Refresh.GameServer/Endpoints/ApiV3/OAuth/DiscordOAuthEndpoints.cs b/Refresh.GameServer/Endpoints/ApiV3/OAuth/DiscordOAuthEndpoints.cs new file mode 100644 index 00000000..0c6fd8bc --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/OAuth/DiscordOAuthEndpoints.cs @@ -0,0 +1,43 @@ +using AttribDoc.Attributes; +using Bunkum.Core; +using Bunkum.Core.Endpoints; +using Refresh.GameServer.Database; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes.Errors; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth.Discord; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Services.OAuth.Clients; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord.Api; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Endpoints.ApiV3.OAuth; + +public class DiscordOAuthEndpoints : EndpointGroup +{ + [ApiV3Endpoint("oauth/discord/currentUserInformation")] + [DocSummary("Gets information about the current user's linked Discord account")] + [DocError(typeof(ApiNotSupportedError), ApiNotSupportedError.OAuthProviderDisabledErrorWhen)] + [DocError(typeof(ApiNotFoundError), ApiNotFoundError.OAuthTokenMissingErrorWhen)] + [DocResponseBody(typeof(ApiDiscordUserResponse))] + public ApiResponse CurrentUserInformation( + RequestContext context, + GameDatabaseContext database, + OAuthService oAuthService, + GameUser user, + IDateTimeProvider timeProvider, + DataContext dataContext) + { + if (!oAuthService.GetOAuthClient(OAuthProvider.Discord, out DiscordOAuthClient? client)) + return ApiNotSupportedError.OAuthProviderDisabledError; + + DiscordApiUserResponse? userInformation = client.GetUserInformation(database, timeProvider, user); + + if (userInformation == null) + return ApiNotFoundError.OAuthTokenMissingError; + + return ApiDiscordUserResponse.FromOld(userInformation, dataContext); + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/OAuth/GitHubOAuthEndpoints.cs b/Refresh.GameServer/Endpoints/ApiV3/OAuth/GitHubOAuthEndpoints.cs new file mode 100644 index 00000000..3aef9ee7 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/OAuth/GitHubOAuthEndpoints.cs @@ -0,0 +1,42 @@ +using AttribDoc.Attributes; +using Bunkum.Core; +using Bunkum.Core.Endpoints; +using Refresh.GameServer.Database; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes.Errors; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Services.OAuth.Clients; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.GitHub; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Endpoints.ApiV3.OAuth; + +public class GitHubOAuthEndpoints : EndpointGroup +{ + [ApiV3Endpoint("oauth/github/currentUserInformation")] + [DocSummary("Gets information about the current user's linked GitHub account")] + [DocError(typeof(ApiNotSupportedError), ApiNotSupportedError.OAuthProviderDisabledErrorWhen)] + [DocError(typeof(ApiNotFoundError), ApiNotFoundError.OAuthTokenMissingErrorWhen)] + [DocResponseBody(typeof(GitHubApiUserResponse))] + public ApiResponse CurrentUserInformation( + RequestContext context, + GameDatabaseContext database, + OAuthService oAuthService, + GameUser user, + IDateTimeProvider timeProvider, + DataContext dataContext) + { + if (!oAuthService.GetOAuthClient(OAuthProvider.GitHub, out GitHubOAuthClient? client)) + return ApiNotSupportedError.OAuthProviderDisabledError; + + GitHubApiUserResponse? userInformation = client.GetUserInformation(database, timeProvider, user); + + if (userInformation == null) + return ApiNotFoundError.OAuthTokenMissingError; + + return userInformation; + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Endpoints/ApiV3/OAuth/OAuthEndpoints.cs b/Refresh.GameServer/Endpoints/ApiV3/OAuth/OAuthEndpoints.cs new file mode 100644 index 00000000..1dfcdba5 --- /dev/null +++ b/Refresh.GameServer/Endpoints/ApiV3/OAuth/OAuthEndpoints.cs @@ -0,0 +1,117 @@ +using System.Diagnostics; +using AttribDoc.Attributes; +using Bunkum.Core; +using Bunkum.Core.Endpoints; +using Bunkum.Core.Responses; +using Bunkum.Protocols.Http; +using Refresh.GameServer.Configuration; +using Refresh.GameServer.Database; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes; +using Refresh.GameServer.Endpoints.ApiV3.ApiTypes.Errors; +using Refresh.GameServer.Endpoints.ApiV3.DataTypes.Response.OAuth; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Services.OAuth.Clients; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Endpoints.ApiV3.OAuth; + +public class OAuthEndpoints : EndpointGroup +{ + [ApiV3Endpoint("oauth/{providerStr}/beginAuthentication")] + [DocSummary("Begins the OAuth authentication process with the specified provider.")] + [DocError(typeof(ApiNotFoundError), ApiNotFoundError.OAuthProviderMissingErrorWhen)] + [DocError(typeof(ApiNotSupportedError), ApiNotSupportedError.OAuthProviderDisabledErrorWhen)] + [DocResponseBody(typeof(ApiOAuthBeginAuthenticationResponse))] + public ApiResponse BeginAuthentication( + RequestContext context, + GameUser user, + GameDatabaseContext database, + OAuthService oAuthService, + IDateTimeProvider timeProvider, + [DocSummary("The OAuth provider to use for the authentication process")] string providerStr) + { + if (!Enum.TryParse(providerStr, true, out OAuthProvider provider)) + return ApiNotFoundError.OAuthProviderMissingError; + + if (!oAuthService.GetOAuthClient(provider, out OAuthClient? client)) + return ApiNotSupportedError.OAuthProviderTokenRevocationUnsupportedError; + + // Create a new OAuth request + string state = database.CreateOAuthRequest(user, timeProvider, provider); + + return new ApiOAuthBeginAuthenticationResponse + { + AuthorizationUrl = client.GetOAuthAuthorizationUrl(state), + }; + } + + [ApiV3Endpoint("oauth/authenticate"), Authentication(false)] + [DocSummary("Finishes OAuth authentication and saves the token to the database. " + + "This isn't meant to be called normally, and is intended as a redirect target of an OAuth authorization request")] + public Response Authenticate( + RequestContext context, + OAuthService oAuthService, + GameServerConfig config, + GameDatabaseContext database, + IDateTimeProvider timeProvider) + { + string? authCode = context.QueryString["code"]; + if (authCode == null) + return BadRequest; + + string? state = context.QueryString["state"]; + if (state == null) + return BadRequest; + + OAuthProvider? provider = database.OAuthGetProviderForRequest(state); + // If the request doesn't exist, then it probably expired or the state data is invalid somehow + if (provider == null) + return BadRequest; + + OAuthClient? client = oAuthService.GetOAuthClient(provider.Value); + Debug.Assert(client != null); + + OAuth2AccessTokenResponse response = client.AcquireTokenFromAuthorizationCode(authCode); + + // Save the OAuth token to the database + GameUser user = database.SaveOAuthToken(state, response, timeProvider, provider.Value); + + context.ResponseHeaders["Location"] = config.WebExternalUrl; + + database.AddNotification("Account Linking Success", $"Your account has been successfully linked to {provider}!", user); + + return Redirect; + } + + [ApiV3Endpoint("oauth/{providerStr}/revokeToken", HttpMethods.Post)] + [DocSummary("Revokes the current user's OAuth token for the specified provider")] + [DocError(typeof(ApiNotFoundError), ApiNotFoundError.OAuthProviderMissingErrorWhen)] + [DocError(typeof(ApiNotFoundError), ApiNotFoundError.OAuthTokenMissingErrorWhen)] + [DocError(typeof(ApiNotSupportedError), ApiNotSupportedError.OAuthProviderTokenRevocationUnsupportedErrorWhen)] + [DocError(typeof(ApiNotSupportedError), ApiNotSupportedError.OAuthProviderDisabledErrorWhen)] + public ApiResponse RevokeToken(RequestContext context, + GameDatabaseContext database, + OAuthService oAuthService, + GameUser user, + [DocSummary("The OAuth provider which provided the token to revoke")] string providerStr) + { + if (!Enum.TryParse(providerStr, true, out OAuthProvider provider)) + return ApiNotFoundError.OAuthProviderMissingError; + + if (!oAuthService.GetOAuthClient(provider, out OAuthClient? client)) + return ApiNotSupportedError.OAuthProviderDisabledError; + + if (!client.TokenRevocationSupported) + return ApiNotSupportedError.OAuthProviderTokenRevocationUnsupportedError; + + OAuthTokenRelation? token = database.GetOAuthTokenFromUser(user, provider); + if (token == null) + return ApiNotFoundError.OAuthTokenMissingError; + + client.RevokeToken(database, token); + + return new ApiOkResponse(); + } +} \ No newline at end of file diff --git a/Refresh.GameServer/RefreshContext.cs b/Refresh.GameServer/RefreshContext.cs index 8a27389d..03825feb 100644 --- a/Refresh.GameServer/RefreshContext.cs +++ b/Refresh.GameServer/RefreshContext.cs @@ -11,4 +11,5 @@ public enum RefreshContext Publishing, Aipi, Presence, + OAuth, } \ No newline at end of file diff --git a/Refresh.GameServer/RefreshGameServer.cs b/Refresh.GameServer/RefreshGameServer.cs index e68a91b4..eb28e30a 100644 --- a/Refresh.GameServer/RefreshGameServer.cs +++ b/Refresh.GameServer/RefreshGameServer.cs @@ -20,6 +20,7 @@ using Refresh.GameServer.Importing; using Refresh.GameServer.Middlewares; using Refresh.GameServer.Services; +using Refresh.GameServer.Services.OAuth; using Refresh.GameServer.Storage; using Refresh.GameServer.Time; using Refresh.GameServer.Types.Data; @@ -40,6 +41,7 @@ public class RefreshGameServer : RefreshServer protected readonly IDataStore _dataStore; protected MatchService _matchService = null!; protected GuidCheckerService _guidCheckerService = null!; + protected OAuthService OAuthService = null!; protected GameServerConfig? _config; protected IntegrationConfig? _integrationConfig; @@ -130,11 +132,10 @@ protected override void SetupServices() usesCustomDigestKey: true, serverDescription: this._config.InstanceDescription, bannerImageUrl: "https://github.com/LittleBigRefresh/Branding/blob/main/logos/refresh_type.png?raw=true"); - - this.Server.AddHealthCheckService(this._databaseProvider, new [] - { + + this.Server.AddHealthCheckService(this._databaseProvider, [ typeof(RealmDatabaseHealthCheck), - }); + ]); this.Server.AddService(); this.Server.AddService(); @@ -147,6 +148,8 @@ protected override void SetupServices() if(this._integrationConfig!.AipiEnabled) this.Server.AddService(); + this.Server.AddService(this.OAuthService = new OAuthService(this.Logger, this._integrationConfig)); + #if DEBUG this.Server.AddService(); #endif @@ -159,7 +162,8 @@ protected override void SetupServices() protected virtual void SetupWorkers() { - this.WorkerManager = new WorkerManager(this.Logger, this._dataStore, this._databaseProvider, this._matchService, this._guidCheckerService); + this.WorkerManager = new WorkerManager(this.Logger, this._dataStore, this._databaseProvider, this._matchService, + this._guidCheckerService, this.GetTimeProvider(), this.OAuthService); this.WorkerManager.AddWorker(); this.WorkerManager.AddWorker(); diff --git a/Refresh.GameServer/Services/OAuth/Clients/DiscordOAuthClient.cs b/Refresh.GameServer/Services/OAuth/Clients/DiscordOAuthClient.cs new file mode 100644 index 00000000..706b4506 --- /dev/null +++ b/Refresh.GameServer/Services/OAuth/Clients/DiscordOAuthClient.cs @@ -0,0 +1,75 @@ +using System.Collections.Specialized; +using System.Net.Http.Headers; +using NotEnoughLogs; +using Refresh.Common.Extensions; +using Refresh.GameServer.Configuration; +using Refresh.GameServer.Database; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord.Api; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Services.OAuth.Clients; + +public class DiscordOAuthClient : OAuthClient +{ + public DiscordOAuthClient(Logger logger, IntegrationConfig integrationConfig) : base(logger) + { + this.ClientId = integrationConfig.DiscordOAuthClientId; + this.ClientSecret = integrationConfig.DiscordOAuthClientSecret; + this.RedirectUri = integrationConfig.DiscordOAuthRedirectUrl; + } + + public override OAuthProvider Provider => OAuthProvider.Discord; + + protected override string TokenEndpoint => "https://discord.com/api/oauth2/token"; + protected override string TokenRevocationEndpoint => "https://discord.com/api/oauth2/token/revoke"; + public override bool TokenRevocationSupported => true; + protected override string ClientId { get; } + protected override string ClientSecret { get; } + protected override string RedirectUri { get; } + + /// + public override string GetOAuthAuthorizationUrl(string state) + { + NameValueCollection queryParams = new() + { + ["client_id"] = this.ClientId, + ["response_type"] = "code", + ["state"] = state, + ["redirect_uri"] = this.RedirectUri, + ["scope"] = "identify", + }; + + return $"https://discord.com/oauth2/authorize{queryParams.ToQueryString()}"; + } + + /// + /// Gets information about a user's linked discord account + /// + /// The database used to access the user's token + /// The time provider for the current request + /// The user to get information on + /// The acquired user information, or null if the token is missing/expired + public DiscordApiUserResponse? GetUserInformation(GameDatabaseContext database, IDateTimeProvider timeProvider, GameUser user) + => this.GetUserInformation(database, timeProvider, database.GetOAuthTokenFromUser(user, OAuthProvider.Discord)); + + private DiscordApiUserResponse? GetUserInformation(GameDatabaseContext database, IDateTimeProvider timeProvider, OAuthTokenRelation? token) + { + if (token == null) + return null; + + HttpResponseMessage? response = this.MakeRequest(token, () => this.CreateRequestMessage(token, HttpMethod.Get, "https://discord.com/api/users/@me"), database, timeProvider); + + if (response == null) + return null; + + if (!response.IsSuccessStatusCode) + throw new Exception($"Request returned status code {response.StatusCode}"); + + if ((response.Content.Headers.ContentLength ?? 0) == 0) + throw new Exception("Request returned no response"); + + return response.Content.ReadAsJson()!; + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Services/OAuth/Clients/GitHubOAuthClient.cs b/Refresh.GameServer/Services/OAuth/Clients/GitHubOAuthClient.cs new file mode 100644 index 00000000..c0b3c9ba --- /dev/null +++ b/Refresh.GameServer/Services/OAuth/Clients/GitHubOAuthClient.cs @@ -0,0 +1,112 @@ +using System.Collections.Specialized; +using System.Net.Http.Headers; +using System.Net.Http.Json; +using System.Text; +using System.Web; +using NotEnoughLogs; +using Refresh.Common.Extensions; +using Refresh.GameServer.Configuration; +using Refresh.GameServer.Database; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.OAuth; +using Refresh.GameServer.Types.OAuth.Discord.Api; +using Refresh.GameServer.Types.OAuth.GitHub; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Services.OAuth.Clients; + +public class GitHubOAuthClient : OAuthClient +{ + public GitHubOAuthClient(Logger logger, IntegrationConfig integrationConfig) : base(logger) + { + this.ClientId = integrationConfig.GitHubOAuthClientId; + this.ClientSecret = integrationConfig.GitHubOAuthClientSecret; + this.RedirectUri = integrationConfig.GitHubOAuthRedirectUrl; + } + + public override void Initialize() + { + base.Initialize(); + + // This isn't strictly required, but will prevent future breakages, + // since GitHub commits to supporting older API versions when they make breaking changes + this.Client.DefaultRequestHeaders.Add("X-GitHub-Api-Version", "2022-11-28"); + } + + public override OAuthProvider Provider => OAuthProvider.GitHub; + protected override string TokenEndpoint => "https://github.com/login/oauth/access_token"; + protected override string TokenRevocationEndpoint => $"https://api.github.com/applications/{this.ClientId}/grant"; + public override bool TokenRevocationSupported => true; + + protected override string ClientId { get; } + protected override string ClientSecret { get; } + protected override string RedirectUri { get; } + + public override string GetOAuthAuthorizationUrl(string state) + { + NameValueCollection queryParams = new() + { + ["client_id"] = this.ClientId, + ["response_type"] = "code", + ["state"] = state, + ["redirect_uri"] = this.RedirectUri, + ["scope"] = "read:user", + }; + + return $"https://github.com/login/oauth/authorize{queryParams.ToQueryString()}"; + } + + private string GetAccessTokenBody(OAuthTokenRelation token) => + $"{{\"access_token\":\"{HttpUtility.JavaScriptStringEncode(token.AccessToken)}\"}}"; + + // GitHub doesn't implement RFC 7009, so we have to write special logic for it :/ + // See https://docs.github.com/en/rest/apps/oauth-applications?apiVersion=2022-11-28#delete-an-app-authorization + public override void RevokeToken(GameDatabaseContext database, OAuthTokenRelation token) + { + HttpRequestMessage message = new(HttpMethod.Delete, this.TokenRevocationEndpoint); + + // this particular endpoint is special, we cant revoke a token by authenticating as the token, only by authenticating as the OAuth app + message.Headers.Authorization = new AuthenticationHeaderValue("Basic", + Convert.ToBase64String(Encoding.UTF8.GetBytes($"{this.ClientId}:{this.ClientSecret}"))); + + message.Headers.Accept.Clear(); + message.Headers.Accept.ParseAdd("application/vnd.github+json"); + + message.Content = new StringContent(this.GetAccessTokenBody(token)); + + HttpResponseMessage response = this.Client.Send(message); + + // This is the success response + if (response.StatusCode == NoContent) + return; + + // This is sent when the token is already invalid + if (response.StatusCode == NotFound) + return; + + if (!response.IsSuccessStatusCode) + throw new Exception($"Got unexpected error status code {response.StatusCode} when revoking token!"); + } + + public GitHubApiUserResponse? GetUserInformation(GameDatabaseContext database, IDateTimeProvider timeProvider, GameUser user) + => this.GetUserInformation(database, timeProvider, database.GetOAuthTokenFromUser(user, OAuthProvider.GitHub)); + + public GitHubApiUserResponse? GetUserInformation(GameDatabaseContext database, IDateTimeProvider timeProvider, OAuthTokenRelation? token) + { + if (token == null) + return null; + + HttpResponseMessage? response = this.MakeRequest(token, () => this.CreateRequestMessage(token, HttpMethod.Get, "https://api.github.com/user"), database, timeProvider); + + if (response == null) + return null; + + if (!response.IsSuccessStatusCode) + throw new Exception($"Request returned status code {response.StatusCode}"); + + if ((response.Content.Headers.ContentLength ?? 0) == 0) + throw new Exception("Request returned no response"); + + return response.Content.ReadAsJson()!; + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Services/OAuth/OAuthClient.cs b/Refresh.GameServer/Services/OAuth/OAuthClient.cs new file mode 100644 index 00000000..567f6324 --- /dev/null +++ b/Refresh.GameServer/Services/OAuth/OAuthClient.cs @@ -0,0 +1,232 @@ +using System.Net.Http.Headers; +using NotEnoughLogs; +using Refresh.Common.Extensions; +using Refresh.GameServer.Database; +using Refresh.GameServer.Time; +using Refresh.GameServer.Types.OAuth; + +namespace Refresh.GameServer.Services.OAuth; + +/// +/// A minimal implementation of the OAuth2 API (RFC 6749), +/// covering the authorization code and refresh token parts of the specification. +/// +/// Also contains an implementation of the token revocation extension of the OAuth2 specification (RFC 7009). +/// +/// +/// +public abstract class OAuthClient : IDisposable +{ + protected HttpClient Client = null!; + + protected readonly Logger Logger; + + /// + /// The provider associated with this OAuth2Service + /// + public abstract OAuthProvider Provider { get; } + + protected abstract string TokenEndpoint { get; } + protected abstract string TokenRevocationEndpoint { get; } + public abstract bool TokenRevocationSupported { get; } + + protected abstract string ClientId { get; } + protected abstract string ClientSecret { get; } + + protected abstract string RedirectUri { get; } + + protected OAuthClient(Logger logger) + { + this.Logger = logger; + } + + public virtual void Initialize() + { + this.Client = new HttpClient(); + + this.Client.DefaultRequestHeaders.Accept.Clear(); + // Explicitly mark that we want JSON responses, as some servers (GitHub) will return URL encoded text instead by default + this.Client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + this.Client.DefaultRequestHeaders.UserAgent.Clear(); + // Default user agent header + this.Client.DefaultRequestHeaders.UserAgent.Add(new ProductInfoHeaderValue("Refresher", VersionInformation.Version)); + } + + /// + /// Constructs a URL to send a user to, which they use to authorize Refresh + /// + /// The `state` parameter of the authorization + /// The authorization URL to redirect the user to + public abstract string GetOAuthAuthorizationUrl(string state); + + /// + /// Acquires an access and refresh token using the provided authorization code + /// + /// The authorization code + /// The acquired access and refresh tokens + public OAuth2AccessTokenResponse AcquireTokenFromAuthorizationCode(string authCode) + { + HttpResponseMessage result = this.Client.PostAsync(this.TokenEndpoint, new FormUrlEncodedContent([ + new KeyValuePair("grant_type", "authorization_code"), + new KeyValuePair("code", authCode), + new KeyValuePair("redirect_uri", this.RedirectUri), + new KeyValuePair("client_id", this.ClientId), + new KeyValuePair("client_secret", this.ClientSecret), + ])).Result; + + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + if (result.StatusCode == BadRequest) + { + OAuth2ErrorResponse errorResponse = result.Content.ReadAsJson()!; + + throw new Exception($"Unexpected error {errorResponse.Error} when acquiring token! Description: {errorResponse.ErrorDescription}, URI: {errorResponse.ErrorUri}"); + } + + if (!result.IsSuccessStatusCode) + throw new Exception($"Acquiring token failed, server returned status code {result.StatusCode}"); + + if ((result.Content.Headers.ContentLength ?? 0) == 0) + throw new Exception("Acquiring token failed, request returned no response"); + + OAuth2AccessTokenResponse response = result.Content.ReadAsJson()!; + + // Case insensitive according to spec + if (!response.TokenType.Equals("bearer", StringComparison.InvariantCultureIgnoreCase)) + throw new Exception("Non-bearer tokens are currently unsupported."); + + return response; + } + + /// + /// Refreshes the passed OAuthTokenRelation, acquiring a new access token + /// + /// The database context associated with the token + /// The token to refresh + /// The time provider associated with the request + /// Whether the refresh succeeded, if failed, assume the token is invalid and authorization has been revoked. + public bool RefreshToken(GameDatabaseContext database, OAuthTokenRelation token, IDateTimeProvider timeProvider) + { + // If we have no refresh token, then always fail the token refresh + if (token.RefreshToken == null) + return false; + + HttpResponseMessage result = this.Client.PostAsync(this.TokenEndpoint, new FormUrlEncodedContent([ + new KeyValuePair("grant_type", "refresh_token"), + new KeyValuePair("refresh_token", token.RefreshToken), + new KeyValuePair("client_id", this.ClientId), + new KeyValuePair("client_secret", this.ClientSecret), + ])).Result; + + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + if (result.StatusCode == BadRequest) + { + OAuth2ErrorResponse errorResponse = result.Content.ReadAsJson()!; + + //Special cased error for when the refresh token is invalid + if (errorResponse.Error == "invalid_grant") + return false; + + throw new Exception($"Unexpected error {errorResponse.Error} when refreshing token! Description: {errorResponse.ErrorDescription}, URI: {errorResponse.ErrorUri}"); + } + + if (!result.IsSuccessStatusCode) + throw new Exception($"Refreshing token failed, server returned status code {result.StatusCode}"); + + if ((result.Content.Headers.ContentLength ?? 0) == 0) + throw new Exception("Refreshing token failed, request returned no response"); + + OAuth2AccessTokenResponse response = result.Content.ReadAsJson()!; + + database.UpdateOAuthToken(token, response, timeProvider); + + return true; + } + + /// + /// Revokes the OAuth token + /// + /// + /// + /// + /// + /// + public virtual void RevokeToken(GameDatabaseContext database, OAuthTokenRelation token) + { + if (!this.TokenRevocationSupported) + throw new NotSupportedException("Revocation is not supported by this OAuth client!"); + + HttpResponseMessage result = this.Client.PostAsync(this.TokenRevocationEndpoint, new FormUrlEncodedContent([ + // Not all services use refresh tokens, so if we do not have one, we need to fall back + new KeyValuePair("token", token.RefreshToken ?? token.AccessToken), + new KeyValuePair("token_type_hint", token.RefreshToken == null ? "access_token" : "refresh_token"), + new KeyValuePair("client_id", this.ClientId), + new KeyValuePair("client_secret", this.ClientSecret), + ])).Result; + + // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + if (result.StatusCode == BadRequest) + { + OAuth2ErrorResponse errorResponse = result.Content.ReadAsJson()!; + + throw new Exception($"Unexpected error {errorResponse.Error} when revoking token! Description: {errorResponse.ErrorDescription}, URI: {errorResponse.ErrorUri}"); + } + + // NOTE: As per https://datatracker.ietf.org/doc/html/rfc7009#autoid-5, revocation of an invalid token returns a 200 OK response, so any other response code is unexpected + if (result.StatusCode != OK) + throw new Exception($"Failed to revoke OAuth token, got status code {result.StatusCode}"); + + database.RevokeOAuthToken(token); + } + + /// + /// Makes a request, automatically attempting to refresh the token if applicable + /// + /// The token to authenticate the request + /// A function used to acquire a HttpRequestMessage instance for this particular request + /// The database to use to revoke/update tokens + /// The time provider for the request + /// The response message from the server + protected HttpResponseMessage? MakeRequest(OAuthTokenRelation token, Func getRequest, GameDatabaseContext database, IDateTimeProvider timeProvider) + { + // If we have passed the revocation date and refreshing the token fails, remove the token from the database and bail out + if (timeProvider.Now > token.AccessTokenRevocationTime && !this.RefreshToken(database, token, timeProvider)) + { + database.RevokeOAuthToken(token); + return null; + } + + HttpRequestMessage request = getRequest(); + HttpResponseMessage response = this.Client.Send(request); + + // Technically the specification does not specify what error response the server sends, + // however Unauthorized is the only one which actually makes sense given the context + if (response.StatusCode == Unauthorized) + { + // If we succeeded at refreshing the token, then try to make the request again + if (this.RefreshToken(database, token, timeProvider)) + return this.Client.Send(getRequest()); + + // If we failed, just revoke the token and bail out + database.RevokeOAuthToken(token); + return null; + } + + return response; + } + + protected HttpRequestMessage CreateRequestMessage(OAuthTokenRelation token, HttpMethod method, string uri) + => this.CreateRequestMessage(token, method, new Uri(uri)); + + protected virtual HttpRequestMessage CreateRequestMessage(OAuthTokenRelation token, HttpMethod method, Uri uri) + { + HttpRequestMessage message = new(method, uri); + message.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token.AccessToken); + + return message; + } + + public void Dispose() + { + this.Client.Dispose(); + } +} diff --git a/Refresh.GameServer/Services/OAuth/OAuthService.cs b/Refresh.GameServer/Services/OAuth/OAuthService.cs new file mode 100644 index 00000000..b969d856 --- /dev/null +++ b/Refresh.GameServer/Services/OAuth/OAuthService.cs @@ -0,0 +1,58 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using Bunkum.Core.Services; +using NotEnoughLogs; +using Refresh.GameServer.Configuration; +using Refresh.GameServer.Services.OAuth.Clients; +using Refresh.GameServer.Types.OAuth; + +namespace Refresh.GameServer.Services.OAuth; + +public class OAuthService : EndpointService +{ + private readonly IntegrationConfig _integrationConfig; + + private readonly Dictionary _clients; + + public OAuthService(Logger logger, IntegrationConfig integrationConfig) : base(logger) + { + this._integrationConfig = integrationConfig; + this._clients = new Dictionary(); + } + + public override void Initialize() + { + base.Initialize(); + + if (this._integrationConfig.DiscordOAuthEnabled) + this._clients[OAuthProvider.Discord] = new DiscordOAuthClient(this.Logger, this._integrationConfig); + if (this._integrationConfig.GitHubOAuthEnabled) + this._clients[OAuthProvider.GitHub] = new GitHubOAuthClient(this.Logger, this._integrationConfig); + + // Initialize all the OAuth clients + foreach ((OAuthProvider provider, OAuthClient? client) in this._clients) + { + this.Logger.LogInfo(RefreshContext.Startup, "Initializing {0} OAuth client", provider); + client.Initialize(); + } + } + + public bool GetOAuthClient(OAuthProvider provider, [MaybeNullWhen(false)] out OAuthClient client) => this._clients.TryGetValue(provider, out client); + public bool GetOAuthClient(OAuthProvider provider, [MaybeNullWhen(false)] out T client) where T : class + { + bool ret = this._clients.TryGetValue(provider, out OAuthClient? rawClient); + + if(rawClient != null) + Debug.Assert(rawClient.GetType().IsAssignableTo(typeof(T)), "Acquired client must be assignable to type parameter"); + + client = rawClient as T; + + return ret; + } + + public T? GetOAuthClient(OAuthProvider provider) where T : class + => this._clients.GetValueOrDefault(provider) as T; + + public OAuthClient? GetOAuthClient(OAuthProvider provider) + => this._clients.GetValueOrDefault(provider); +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/Data/DataContext.cs b/Refresh.GameServer/Types/Data/DataContext.cs index 831be29f..384ee3c3 100644 --- a/Refresh.GameServer/Types/Data/DataContext.cs +++ b/Refresh.GameServer/Types/Data/DataContext.cs @@ -3,6 +3,8 @@ using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; using Refresh.GameServer.Services; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Time; using Refresh.GameServer.Types.UserData; namespace Refresh.GameServer.Types.Data; @@ -14,6 +16,8 @@ public class DataContext public required IDataStore DataStore; public required MatchService Match; public required GuidCheckerService GuidChecker; + public required IDateTimeProvider TimeProvider; + public required OAuthService OAuth; public required Token? Token; public GameUser? User => this.Token?.User; diff --git a/Refresh.GameServer/Types/Data/DataContextService.cs b/Refresh.GameServer/Types/Data/DataContextService.cs index 45d132df..fae9828a 100644 --- a/Refresh.GameServer/Types/Data/DataContextService.cs +++ b/Refresh.GameServer/Types/Data/DataContextService.cs @@ -7,6 +7,8 @@ using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; using Refresh.GameServer.Services; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Time; namespace Refresh.GameServer.Types.Data; @@ -16,15 +18,21 @@ public class DataContextService : Service private readonly MatchService _matchService; private readonly AuthenticationService _authService; private readonly GuidCheckerService _guidCheckerService; + private readonly TimeProviderService _timeProviderService; + private readonly OAuthService _oAuthService; - public DataContextService(StorageService storage, MatchService match, AuthenticationService auth, Logger logger, GuidCheckerService guidChecker) : base(logger) + public DataContextService(StorageService storage, MatchService match, AuthenticationService auth, Logger logger, + GuidCheckerService guidChecker, TimeProviderService timeProviderService, + OAuthService oAuthService) : base(logger) { this._storageService = storage; this._matchService = match; this._authService = auth; this._guidCheckerService = guidChecker; + this._timeProviderService = timeProviderService; + this._oAuthService = oAuthService; } - + private static T StealInjection(Service service, ListenerContext? context = null, Lazy? database = null, string name = "") { return (T)service.AddParameterToEndpoint(context!, new BunkumParameterInfo(typeof(T), name), database!)!; @@ -42,6 +50,8 @@ private static T StealInjection(Service service, ListenerContext? context = n Match = this._matchService, Token = StealInjection(this._authService, context, database), GuidChecker = this._guidCheckerService, + TimeProvider = StealInjection(this._timeProviderService), + OAuth = StealInjection(this._oAuthService), }; } diff --git a/Refresh.GameServer/Types/OAuth/Discord/DiscordApiUserResponse.cs b/Refresh.GameServer/Types/OAuth/Discord/DiscordApiUserResponse.cs new file mode 100644 index 00000000..e03d1e99 --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/Discord/DiscordApiUserResponse.cs @@ -0,0 +1,34 @@ +namespace Refresh.GameServer.Types.OAuth.Discord.Api; + +[JsonObject(NamingStrategyType = typeof(SnakeCaseNamingStrategy))] +public class DiscordApiUserResponse +{ + /// + /// The user's ID, as a snowflake + /// + public ulong Id { get; set; } + /// + /// The user's username + /// + public string Username { get; set; } + /// + /// The user's discord tag + /// + public string Discriminator { get; set; } + /// + /// The user's global name, if set + /// + public string? GlobalName { get; set; } + /// + /// The hash of the user's avatar + /// + public string? Avatar { get; set; } + /// + /// The hash of the user's banner + /// + public string? Banner { get; set; } + /// + /// The user's accent colour + /// + public uint? AccentColor { get; set; } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/OAuth/GitHub/GitHubApiUserResponse.cs b/Refresh.GameServer/Types/OAuth/GitHub/GitHubApiUserResponse.cs new file mode 100644 index 00000000..1b80a1af --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/GitHub/GitHubApiUserResponse.cs @@ -0,0 +1,112 @@ +namespace Refresh.GameServer.Types.OAuth.GitHub; + +#nullable disable + +public class GitHubApiUserResponse +{ + [JsonProperty("avatar_url")] public Uri AvatarUrl { get; set; } + + [JsonProperty("bio")] public string Bio { get; set; } + + [JsonProperty("blog")] public string Blog { get; set; } + + [JsonProperty("business_plus", NullValueHandling = NullValueHandling.Ignore)] + public bool? BusinessPlus { get; set; } + + [JsonProperty("collaborators", NullValueHandling = NullValueHandling.Ignore)] + public long? Collaborators { get; set; } + + [JsonProperty("company")] public string Company { get; set; } + + [JsonProperty("created_at")] public DateTimeOffset CreatedAt { get; set; } + + [JsonProperty("disk_usage", NullValueHandling = NullValueHandling.Ignore)] + public long? DiskUsage { get; set; } + + [JsonProperty("email")] public string Email { get; set; } + + [JsonProperty("events_url")] public string EventsUrl { get; set; } + + [JsonProperty("followers")] public long Followers { get; set; } + + [JsonProperty("followers_url")] public Uri FollowersUrl { get; set; } + + [JsonProperty("following")] public long Following { get; set; } + + [JsonProperty("following_url")] public string FollowingUrl { get; set; } + + [JsonProperty("gists_url")] public string GistsUrl { get; set; } + + [JsonProperty("gravatar_id")] public string GravatarId { get; set; } + + [JsonProperty("hireable")] public bool? Hireable { get; set; } + + [JsonProperty("html_url")] public Uri HtmlUrl { get; set; } + + [JsonProperty("id")] public long Id { get; set; } + + [JsonProperty("ldap_dn", NullValueHandling = NullValueHandling.Ignore)] + public string LdapDn { get; set; } + + [JsonProperty("location")] public string Location { get; set; } + + [JsonProperty("login")] public string Login { get; set; } + + [JsonProperty("name")] public string Name { get; set; } + + [JsonProperty("node_id")] public string NodeId { get; set; } + + [JsonProperty("notification_email")] public string NotificationEmail { get; set; } + + [JsonProperty("organizations_url")] public Uri OrganizationsUrl { get; set; } + + [JsonProperty("owned_private_repos", NullValueHandling = NullValueHandling.Ignore)] + public long? OwnedPrivateRepos { get; set; } + + [JsonProperty("plan", NullValueHandling = NullValueHandling.Ignore)] + public GitHubApiPlanResponse Plan { get; set; } + + [JsonProperty("private_gists", NullValueHandling = NullValueHandling.Ignore)] + public long? PrivateGists { get; set; } + + [JsonProperty("public_gists")] public long PublicGists { get; set; } + + [JsonProperty("public_repos")] public long PublicRepos { get; set; } + + [JsonProperty("received_events_url")] public Uri ReceivedEventsUrl { get; set; } + + [JsonProperty("repos_url")] public Uri ReposUrl { get; set; } + + [JsonProperty("site_admin")] public bool SiteAdmin { get; set; } + + [JsonProperty("starred_url")] public string StarredUrl { get; set; } + + [JsonProperty("subscriptions_url")] public Uri SubscriptionsUrl { get; set; } + + [JsonProperty("suspended_at")] public DateTimeOffset? SuspendedAt { get; set; } + + [JsonProperty("total_private_repos", NullValueHandling = NullValueHandling.Ignore)] + public long? TotalPrivateRepos { get; set; } + + [JsonProperty("twitter_username")] public string TwitterUsername { get; set; } + + [JsonProperty("two_factor_authentication", NullValueHandling = NullValueHandling.Ignore)] + public bool? TwoFactorAuthentication { get; set; } + + [JsonProperty("type")] public string Type { get; set; } + + [JsonProperty("updated_at")] public DateTimeOffset UpdatedAt { get; set; } + + [JsonProperty("url")] public Uri Url { get; set; } + + public class GitHubApiPlanResponse + { + [JsonProperty("collaborators")] public long Collaborators { get; set; } + + [JsonProperty("name")] public string Name { get; set; } + + [JsonProperty("private_repos")] public long PrivateRepos { get; set; } + + [JsonProperty("space")] public long Space { get; set; } + } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/OAuth/OAuth2AccessTokenResponse.cs b/Refresh.GameServer/Types/OAuth/OAuth2AccessTokenResponse.cs new file mode 100644 index 00000000..b7146278 --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/OAuth2AccessTokenResponse.cs @@ -0,0 +1,15 @@ +using JetBrains.Annotations; + +namespace Refresh.GameServer.Types.OAuth; + +#nullable disable + +[JsonObject(NamingStrategyType = typeof(SnakeCaseNamingStrategy))] +public class OAuth2AccessTokenResponse +{ + public string AccessToken { get; set; } + public string TokenType { get; set; } + public double? ExpiresIn { get; set; } + [CanBeNull] public string RefreshToken { get; set; } + [CanBeNull] public string Scope { get; set; } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/OAuth/OAuth2ErrorResponse.cs b/Refresh.GameServer/Types/OAuth/OAuth2ErrorResponse.cs new file mode 100644 index 00000000..ad6ae510 --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/OAuth2ErrorResponse.cs @@ -0,0 +1,13 @@ +using JetBrains.Annotations; + +namespace Refresh.GameServer.Types.OAuth; + +#nullable disable + +[JsonObject(NamingStrategyType = typeof(SnakeCaseNamingStrategy))] +public class OAuth2ErrorResponse +{ + public string Error { get; set; } + [CanBeNull] public string ErrorDescription { get; set; } + [CanBeNull] public string ErrorUri { get; set; } +} diff --git a/Refresh.GameServer/Types/OAuth/OAuthProvider.cs b/Refresh.GameServer/Types/OAuth/OAuthProvider.cs new file mode 100644 index 00000000..eda2a915 --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/OAuthProvider.cs @@ -0,0 +1,7 @@ +namespace Refresh.GameServer.Types.OAuth; + +public enum OAuthProvider +{ + Discord, + GitHub, +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/OAuth/OAuthRequest.cs b/Refresh.GameServer/Types/OAuth/OAuthRequest.cs new file mode 100644 index 00000000..45c9fcd4 --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/OAuthRequest.cs @@ -0,0 +1,22 @@ +using Realms; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Types.OAuth; + +public partial class OAuthRequest : IRealmObject +{ + [PrimaryKey] + public string State { get; set; } + public GameUser User { get; set; } + public DateTimeOffset ExpiresAt { get; set; } + + [Ignored] + public OAuthProvider Provider + { + get => (OAuthProvider)this._Provider; + set => this._Provider = (int)value; + } + + // ReSharper disable once InconsistentNaming + public int _Provider { get; set; } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/OAuth/OAuthTokenRelation.cs b/Refresh.GameServer/Types/OAuth/OAuthTokenRelation.cs new file mode 100644 index 00000000..d92e890e --- /dev/null +++ b/Refresh.GameServer/Types/OAuth/OAuthTokenRelation.cs @@ -0,0 +1,35 @@ +using JetBrains.Annotations; +using Realms; +using Refresh.GameServer.Types.UserData; + +namespace Refresh.GameServer.Types.OAuth; + +#nullable disable + +public partial class OAuthTokenRelation : IRealmObject +{ + public GameUser User { get; set; } + + [Ignored] + public OAuthProvider Provider + { + get => (OAuthProvider)this._Provider; + set => this._Provider = (int)value; + } + + // ReSharper disable once InconsistentNaming + public int _Provider { get; set; } + + /// + /// The user's access token + /// + public string AccessToken { get; set; } + /// + /// The time the access token gets revoked + /// + public DateTimeOffset AccessTokenRevocationTime { get; set; } + /// + /// The refresh token used to get a new access token + /// + [CanBeNull] public string RefreshToken { get; set; } +} \ No newline at end of file diff --git a/Refresh.GameServer/Types/UserData/GameUser.cs b/Refresh.GameServer/Types/UserData/GameUser.cs index 52d84e35..7340fdcd 100644 --- a/Refresh.GameServer/Types/UserData/GameUser.cs +++ b/Refresh.GameServer/Types/UserData/GameUser.cs @@ -78,9 +78,6 @@ public partial class GameUser : IRealmObject, IRateLimitUser public bool RpcnAuthenticationAllowed { get; set; } public bool PsnAuthenticationAllowed { get; set; } - private int _ProfileVisibility { get; set; } = (int)Visibility.All; - private int _LevelVisibility { get; set; } = (int)Visibility.All; - /// /// The auth token the presence server knows this user by, null if not connected to the presence server /// @@ -92,6 +89,11 @@ public partial class GameUser : IRealmObject, IRateLimitUser /// public GamePlaylist? RootPlaylist { get; set; } + private int _ProfileVisibility { get; set; } = (int)Visibility.All; + private int _LevelVisibility { get; set; } = (int)Visibility.All; + private int _DiscordProfileVisibility { get; set; } = (int)Visibility.LoggedInUsers; + private int _GitHubProfileVisibility { get; set; } = (int)Visibility.LoggedInUsers; + /// /// Whether the user's profile information is exposed in the public API. /// @@ -111,7 +113,27 @@ public Visibility LevelVisibility get => (Visibility)this._LevelVisibility; set => this._LevelVisibility = (int)value; } + + /// + /// Whether the user's discord profile is exposed in the public API + /// + [Ignored] + public Visibility DiscordProfileVisibility + { + get => (Visibility)this._DiscordProfileVisibility; + set => this._DiscordProfileVisibility = (int)value; + } + /// + /// Whether the user's discord profile is exposed in the public API + /// + [Ignored] + public Visibility GitHubProfileVisibility + { + get => (Visibility)this._GitHubProfileVisibility; + set => this._GitHubProfileVisibility = (int)value; + } + /// /// If `true`, unescape XML tags sent to /filter /// diff --git a/Refresh.GameServer/Types/Visibility.cs b/Refresh.GameServer/Types/Visibility.cs index 44674418..04bff4c9 100644 --- a/Refresh.GameServer/Types/Visibility.cs +++ b/Refresh.GameServer/Types/Visibility.cs @@ -1,5 +1,8 @@ using System.Xml.Serialization; +using Refresh.GameServer.Authentication; using Refresh.GameServer.Endpoints.Game.Handshake; +using Refresh.GameServer.Types.Data; +using Refresh.GameServer.Types.UserData; namespace Refresh.GameServer.Types; @@ -10,18 +13,45 @@ namespace Refresh.GameServer.Types; public enum Visibility { /// - /// User is okay with content being shown everywhere + /// User is okay with content being shown everywhere at all times /// [XmlEnum("all")] All = 0, /// - /// User only allows content to be shown in-game and on website + /// User only allows content to be shown in-game and on the website to authenticated viewers /// [XmlEnum("psn")] // Yes it says PSN, but in-game it is described as "users who are logged into PSN on the website" - Website = 1, + LoggedInUsers = 1, /// /// User only allows content to be shown in-game /// [XmlEnum("game")] Game = 2, +} + +public static class VisibilityExtensions +{ + /// + /// Filters the passed object depending on the passed DataContext and intended visibility + /// + /// The intended visibility of the object + /// The intended visibility of the object + /// The data context of the request asking for the object + /// The object to filter + /// The filtered object + public static T? Filter(this Visibility visibility, GameUser owner, DataContext dataContext, T? obj) where T : class + { + if (dataContext.User?.UserId == owner.UserId) + return obj; + + switch (visibility) + { + case Visibility.Game when dataContext.Game != TokenGame.Website: + case Visibility.LoggedInUsers when dataContext.Token != null: + case Visibility.All: + return obj; + default: + return null; + } + } } \ No newline at end of file diff --git a/Refresh.GameServer/Workers/ExpiredObjectWorker.cs b/Refresh.GameServer/Workers/ExpiredObjectWorker.cs index 126fe25b..9445cf9f 100644 --- a/Refresh.GameServer/Workers/ExpiredObjectWorker.cs +++ b/Refresh.GameServer/Workers/ExpiredObjectWorker.cs @@ -2,6 +2,7 @@ using NotEnoughLogs; using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; +using Refresh.GameServer.Time; using Refresh.GameServer.Types.Data; using Refresh.GameServer.Types.UserData; @@ -35,5 +36,9 @@ public void DoWork(DataContext context) context.Logger.LogInfo(RefreshContext.Worker, $"Removed {token.User}'s {token.TokenType} token since it has expired {DateTimeOffset.Now - token.ExpiresAt} ago"); context.Database.RevokeToken(token); } + + int expiredOAuthRequests = context.Database.RemoveAllExpiredOAuthRequests(context.TimeProvider); + if(expiredOAuthRequests > 0) + context.Logger.LogInfo(RefreshContext.Worker, "Removed {0} OAuth requests, since they have expired", expiredOAuthRequests); } } \ No newline at end of file diff --git a/Refresh.GameServer/Workers/WorkerManager.cs b/Refresh.GameServer/Workers/WorkerManager.cs index 6f88ebdb..fab06793 100644 --- a/Refresh.GameServer/Workers/WorkerManager.cs +++ b/Refresh.GameServer/Workers/WorkerManager.cs @@ -2,6 +2,8 @@ using NotEnoughLogs; using Refresh.GameServer.Database; using Refresh.GameServer.Services; +using Refresh.GameServer.Services.OAuth; +using Refresh.GameServer.Time; using Refresh.GameServer.Types.Data; namespace Refresh.GameServer.Workers; @@ -13,14 +15,20 @@ public class WorkerManager private readonly GameDatabaseProvider _databaseProvider; private readonly MatchService _matchService; private readonly GuidCheckerService _guidCheckerService; + private readonly IDateTimeProvider _timeProvider; + private readonly OAuthService _oAuthService; - public WorkerManager(Logger logger, IDataStore dataStore, GameDatabaseProvider databaseProvider, MatchService matchService, GuidCheckerService guidCheckerService) + public WorkerManager(Logger logger, IDataStore dataStore, GameDatabaseProvider databaseProvider, + MatchService matchService, GuidCheckerService guidCheckerService, IDateTimeProvider timeProvider, + OAuthService oAuthService) { this._dataStore = dataStore; this._databaseProvider = databaseProvider; this._logger = logger; this._matchService = matchService; this._guidCheckerService = guidCheckerService; + this._timeProvider = timeProvider; + this._oAuthService = oAuthService; } private Thread? _thread = null; @@ -49,6 +57,8 @@ private void RunWorkCycle() Match = this._matchService, Token = null, GuidChecker = this._guidCheckerService, + TimeProvider = this._timeProvider, + OAuth = this._oAuthService, }); foreach (IWorker worker in this._workers) diff --git a/RefreshTests.GameServer/Extensions/HttpContentExtensions.cs b/RefreshTests.GameServer/Extensions/HttpContentExtensions.cs deleted file mode 100644 index e0a04e36..00000000 --- a/RefreshTests.GameServer/Extensions/HttpContentExtensions.cs +++ /dev/null @@ -1,14 +0,0 @@ -using System.Xml; -using System.Xml.Serialization; - -namespace RefreshTests.GameServer.Extensions; - -public static class HttpContentExtensions -{ - public static T ReadAsXML(this HttpContent content) - { - XmlSerializer serializer = new(typeof(T)); - - return (T)serializer.Deserialize(new XmlTextReader(content.ReadAsStream()))!; - } -} \ No newline at end of file diff --git a/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs b/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs index e25ca416..0606c1d1 100644 --- a/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs +++ b/RefreshTests.GameServer/GameServer/TestRefreshGameServer.cs @@ -11,6 +11,7 @@ using Refresh.GameServer.Configuration; using Refresh.GameServer.Database; using Refresh.GameServer.Services; +using Refresh.GameServer.Services.OAuth; using Refresh.GameServer.Time; using Refresh.GameServer.Types.Data; using Refresh.GameServer.Types.Levels.Categories; @@ -76,6 +77,7 @@ protected override void SetupServices() this.Server.AddService(); this.Server.AddService(); this.Server.AddService(); + this.Server.AddService(); // Must always be last, see comment in RefreshGameServer this.Server.AddService(); diff --git a/RefreshTests.GameServer/TestContext.cs b/RefreshTests.GameServer/TestContext.cs index 3274aead..a59a9dd2 100644 --- a/RefreshTests.GameServer/TestContext.cs +++ b/RefreshTests.GameServer/TestContext.cs @@ -6,8 +6,7 @@ using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; using Refresh.GameServer.Services; -using Refresh.GameServer.Types; -using Refresh.GameServer.Types.Contests; +using Refresh.GameServer.Services.OAuth; using Refresh.GameServer.Types.Data; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Roles; @@ -165,6 +164,8 @@ public DataContext GetDataContext(Token? token = null) Match = this.GetService(), Token = token, GuidChecker = this.GetService(), + TimeProvider = this.Time, + OAuth = this.GetService(), }; } diff --git a/RefreshTests.GameServer/Tests/Assets/AssetUploadTests.cs b/RefreshTests.GameServer/Tests/Assets/AssetUploadTests.cs index 212b1b8a..ef6d657e 100644 --- a/RefreshTests.GameServer/Tests/Assets/AssetUploadTests.cs +++ b/RefreshTests.GameServer/Tests/Assets/AssetUploadTests.cs @@ -1,4 +1,5 @@ using System.Security.Cryptography; +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Configuration; using Refresh.GameServer.Services; @@ -415,7 +416,7 @@ public void CheckForMissingAssets(bool psp) HttpResponseMessage response = client.PostAsync("/lbp/filterResources", new StringContent(new SerializedResourceList(new[] { hash }).AsXML())).Result; Assert.That(response.StatusCode, Is.EqualTo(OK)); - SerializedResourceList missingList = response.Content.ReadAsXML(); + SerializedResourceList missingList = response.Content.ReadAsXml(); Assert.That(missingList.Items, Has.Count.EqualTo(1)); Assert.That(missingList.Items[0], Is.EqualTo(hash)); @@ -427,7 +428,7 @@ public void CheckForMissingAssets(bool psp) response = client.PostAsync("/lbp/filterResources", new StringContent(new SerializedResourceList(new[] { hash }).AsXML())).Result; Assert.That(response.StatusCode, Is.EqualTo(OK)); - missingList = response.Content.ReadAsXML(); + missingList = response.Content.ReadAsXml(); Assert.That(missingList.Items, Has.Count.EqualTo(0)); } diff --git a/RefreshTests.GameServer/Tests/Comments/CommentTests.cs b/RefreshTests.GameServer/Tests/Comments/CommentTests.cs index deacd791..5005f494 100644 --- a/RefreshTests.GameServer/Tests/Comments/CommentTests.cs +++ b/RefreshTests.GameServer/Tests/Comments/CommentTests.cs @@ -1,4 +1,5 @@ -using Refresh.GameServer.Types.Comments; +using Refresh.Common.Extensions; +using Refresh.GameServer.Types.Comments; using Refresh.GameServer.Types.Lists; using Refresh.GameServer.Types.Reviews; using Refresh.GameServer.Types.UserData; @@ -23,7 +24,7 @@ public static void RateComment(TestContext context, GameUser user, IGameComment } HttpResponseMessage response = client.GetAsync(getCommentsUrl).Result; - SerializedCommentList userComments = response.Content.ReadAsXML(); + SerializedCommentList userComments = response.Content.ReadAsXml(); SerializedComment serializedComment = userComments.Items.First(); int expectedThumbsUp, expectedThumbsDown; diff --git a/RefreshTests.GameServer/Tests/Comments/LevelCommentTests.cs b/RefreshTests.GameServer/Tests/Comments/LevelCommentTests.cs index 86639eaa..48165132 100644 --- a/RefreshTests.GameServer/Tests/Comments/LevelCommentTests.cs +++ b/RefreshTests.GameServer/Tests/Comments/LevelCommentTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Comments; using Refresh.GameServer.Types.Levels; @@ -30,7 +31,7 @@ public void PostAndDeleteLevelComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client.GetAsync($"/lbp/comments/user/{level.LevelId}").Result; - SerializedCommentList userComments = response.Content.ReadAsXML(); + SerializedCommentList userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(1)); Assert.That(userComments.Items[0].Content, Is.EqualTo(comment.Content)); @@ -38,7 +39,7 @@ public void PostAndDeleteLevelComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client.GetAsync($"/lbp/comments/user/{level.LevelId}").Result; - userComments = response.Content.ReadAsXML(); + userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(0)); } @@ -156,7 +157,7 @@ public void CantDeleteAnotherUsersComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client1.GetAsync($"/lbp/comments/user/{level.LevelId}").Result; - SerializedCommentList userComments = response.Content.ReadAsXML(); + SerializedCommentList userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(1)); Assert.That(userComments.Items[0].Content, Is.EqualTo(comment.Content)); diff --git a/RefreshTests.GameServer/Tests/Comments/UserCommentTests.cs b/RefreshTests.GameServer/Tests/Comments/UserCommentTests.cs index f062d70e..24cd07a9 100644 --- a/RefreshTests.GameServer/Tests/Comments/UserCommentTests.cs +++ b/RefreshTests.GameServer/Tests/Comments/UserCommentTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Comments; using Refresh.GameServer.Types.Lists; @@ -29,7 +30,7 @@ public void PostAndDeleteUserComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client.GetAsync($"/lbp/userComments/{user2.Username}").Result; - SerializedCommentList userComments = response.Content.ReadAsXML(); + SerializedCommentList userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(1)); Assert.That(userComments.Items[0].Content, Is.EqualTo(comment.Content)); @@ -37,7 +38,7 @@ public void PostAndDeleteUserComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client.GetAsync($"/lbp/userComments/{user2.Username}").Result; - userComments = response.Content.ReadAsXML(); + userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(0)); } @@ -152,7 +153,7 @@ public void CantDeleteAnotherUsersComment() Assert.That(response.StatusCode, Is.EqualTo(OK)); response = client1.GetAsync($"/lbp/userComments/{user2.Username}").Result; - SerializedCommentList userComments = response.Content.ReadAsXML(); + SerializedCommentList userComments = response.Content.ReadAsXml(); Assert.That(userComments.Items, Has.Count.EqualTo(1)); Assert.That(userComments.Items[0].Content, Is.EqualTo(comment.Content)); diff --git a/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs b/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs index 93368343..345619ff 100644 --- a/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs +++ b/RefreshTests.GameServer/Tests/Levels/LevelListOverrideTests.cs @@ -1,5 +1,6 @@ using MongoDB.Bson; using NotEnoughLogs; +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Configuration; using Refresh.GameServer.Services; @@ -49,7 +50,7 @@ public void CanGetOverriddenLevels() // This can be any endpoint that doesnt return all levels but I chose mmpicks HttpResponseMessage message = client.GetAsync("/lbp/slots/mmpicks").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList levelList = message.Content.ReadAsXML(); + SerializedMinimalLevelList levelList = message.Content.ReadAsXml(); Assert.That(levelList.Items, Is.Empty); //Make sure we dont have an override set @@ -68,14 +69,14 @@ public void CanGetOverriddenLevels() //Get the slots, and make sure it contains the level we set as the override message = client.GetAsync("/lbp/slots/mmpicks").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - levelList = message.Content.ReadAsXML(); + levelList = message.Content.ReadAsXml(); Assert.That(levelList.Items, Has.Count.EqualTo(1)); Assert.That(levelList.Items[0].LevelId, Is.EqualTo(level.LevelId)); //Verify the team picks slot list has stopped pointing to the user override message = client.GetAsync("/lbp/slots/mmpicks").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - levelList = message.Content.ReadAsXML(); + levelList = message.Content.ReadAsXml(); Assert.That(levelList.Items, Is.Empty); } } \ No newline at end of file diff --git a/RefreshTests.GameServer/Tests/Levels/LevelTests.cs b/RefreshTests.GameServer/Tests/Levels/LevelTests.cs index 3f890613..bdfda5e8 100644 --- a/RefreshTests.GameServer/Tests/Levels/LevelTests.cs +++ b/RefreshTests.GameServer/Tests/Levels/LevelTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Endpoints.Game.DataTypes.Response; using Refresh.GameServer.Types.Levels; @@ -22,7 +23,7 @@ public void SlotsNewest() HttpResponseMessage message = client.GetAsync($"/lbp/slots/newest").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); @@ -30,7 +31,7 @@ public void SlotsNewest() message = client.GetAsync($"/lbp/slots").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); } @@ -56,7 +57,7 @@ void TestSeed(GameLevel expectedLevel, int seed) HttpResponseMessage message = client.GetAsync($"/lbp/slots/lbp2luckydip?seed={seed}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(3)); Assert.That(result.Items.First().LevelId, Is.EqualTo(expectedLevel.LevelId)); } @@ -81,7 +82,7 @@ public void SlotsQueued() HttpResponseMessage message = client.GetAsync($"/lbp/slots/lolcatftw").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); //Add the level to the queue @@ -92,7 +93,7 @@ public void SlotsQueued() message = client.GetAsync($"/lbp/slots/lolcatftw").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); @@ -104,7 +105,7 @@ public void SlotsQueued() message = client.GetAsync($"/lbp/slots/lolcatftw").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -121,7 +122,7 @@ public void SlotsHearted() HttpResponseMessage message = client.GetAsync($"/lbp/slots/favouriteSlots").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its empty - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); //Favourite the level @@ -132,7 +133,7 @@ public void SlotsHearted() message = client.GetAsync($"/lbp/slots/favouriteSlots").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure the only entry is the level - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); @@ -144,7 +145,7 @@ public void SlotsHearted() message = client.GetAsync($"/lbp/slots/favouriteSlots").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its now empty - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -161,7 +162,7 @@ public void SlotsHeartedQuirk() HttpResponseMessage message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its empty - SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXML(); + SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); //Favourite the level @@ -172,7 +173,7 @@ public void SlotsHeartedQuirk() message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure the only entry is the level - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); @@ -184,7 +185,7 @@ public void SlotsHeartedQuirk() message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its now empty - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -217,7 +218,7 @@ public void SlotsMostHearted() HttpResponseMessage message = client.GetAsync($"/lbp/slots/mostHearted").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); context.Database.FavouriteLevel(level, user1); @@ -227,7 +228,7 @@ public void SlotsMostHearted() message = client.GetAsync($"/lbp/slots/mostHearted").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level2.LevelId)); @@ -238,7 +239,7 @@ public void SlotsMostHearted() message = client.GetAsync($"/lbp/slots/mostHearted").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level2.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level.LevelId)); @@ -268,7 +269,7 @@ public void SlotsMostLiked() HttpResponseMessage message = client.GetAsync($"/lbp/slots/highestRated").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); bool a = context.Database.RateLevel(level, user, RatingType.Yay); @@ -278,7 +279,7 @@ public void SlotsMostLiked() message = client.GetAsync($"/lbp/slots/highestRated").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level2.LevelId)); @@ -289,7 +290,7 @@ public void SlotsMostLiked() message = client.GetAsync($"/lbp/slots/highestRated").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level2.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level.LevelId)); @@ -312,7 +313,7 @@ public void SlotsMostPlayed() HttpResponseMessage message = client.GetAsync($"/lbp/slots/mostUniquePlays").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); context.Database.PlayLevel(level, user, 1); @@ -324,7 +325,7 @@ public void SlotsMostPlayed() message = client.GetAsync($"/lbp/slots/mostUniquePlays").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level2.LevelId)); @@ -347,7 +348,7 @@ public void SlotsMostReplayed() HttpResponseMessage message = client.GetAsync($"/lbp/slots/mostPlays").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); context.Database.PlayLevel(level, user, 1); @@ -359,7 +360,7 @@ public void SlotsMostReplayed() message = client.GetAsync($"/lbp/slots/mostPlays").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level2.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level.LevelId)); @@ -377,7 +378,7 @@ public void SlotsTeamPicked() HttpResponseMessage message = client.GetAsync($"/lbp/slots/mmpicks").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); context.Database.AddTeamPickToLevel(level); @@ -385,7 +386,7 @@ public void SlotsTeamPicked() message = client.GetAsync($"/lbp/slots/mmpicks").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); } @@ -401,7 +402,7 @@ public void SlotsByUser() HttpResponseMessage message = client.GetAsync($"/lbp/slots/by/{publisher.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelList result = message.Content.ReadAsXML(); + SerializedMinimalLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); GameLevel level = context.CreateLevel(publisher); @@ -409,7 +410,7 @@ public void SlotsByUser() message = client.GetAsync($"/lbp/slots/by/{publisher.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); @@ -429,7 +430,7 @@ public void GetLevelById() HttpResponseMessage message = client.GetAsync($"/lbp/s/user/{level.LevelId}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - GameLevelResponse result = message.Content.ReadAsXML(); + GameLevelResponse result = message.Content.ReadAsXml(); Assert.That(result.LevelId, Is.EqualTo(level.LevelId)); message = client.GetAsync($"/lbp/s/user/{int.MaxValue}").Result; @@ -449,7 +450,7 @@ public void GetSlotList() HttpResponseMessage message = client.GetAsync($"/lbp/slotList?s={level.LevelId}&s={level2.LevelId}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedLevelList result = message.Content.ReadAsXML(); + SerializedLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level2.LevelId)); @@ -468,7 +469,7 @@ public void GetLevelsFromCategory() HttpResponseMessage message = client.GetAsync($"/lbp/searches/newest").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMinimalLevelResultsList result = message.Content.ReadAsXML(); + SerializedMinimalLevelResultsList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(2)); Assert.That(result.Items[0].LevelId, Is.EqualTo(level.LevelId)); Assert.That(result.Items[1].LevelId, Is.EqualTo(level2.LevelId)); @@ -508,7 +509,7 @@ public void GetSlotListWhenInvalidLevel() HttpResponseMessage message = client.GetAsync($"/lbp/slotList?s={int.MaxValue}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedLevelList result = message.Content.ReadAsXML(); + SerializedLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -524,13 +525,13 @@ public void GetModernCategories() Assert.That(message.StatusCode, Is.EqualTo(OK)); //Just throw away the value, but at least make sure it parses - _ = message.Content.ReadAsXML(); + _ = message.Content.ReadAsXml(); message = client.GetAsync($"/lbp/searches").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Just throw away the value, but at least make sure it parses - _ = message.Content.ReadAsXML(); + _ = message.Content.ReadAsXml(); } [Test] diff --git a/RefreshTests.GameServer/Tests/Levels/PublishEndpointsTests.cs b/RefreshTests.GameServer/Tests/Levels/PublishEndpointsTests.cs index d739f2f6..f7bab7d9 100644 --- a/RefreshTests.GameServer/Tests/Levels/PublishEndpointsTests.cs +++ b/RefreshTests.GameServer/Tests/Levels/PublishEndpointsTests.cs @@ -1,4 +1,5 @@ using Refresh.Common.Constants; +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Database; using Refresh.GameServer.Endpoints.Game.DataTypes.Request; @@ -47,7 +48,7 @@ public void PublishLevel() HttpResponseMessage message = client.PostAsync("/lbp/startPublish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedLevelResources resourcesToUpload = message.Content.ReadAsXML(); + SerializedLevelResources resourcesToUpload = message.Content.ReadAsXml(); Assert.That(resourcesToUpload.Resources, Has.Length.EqualTo(1)); Assert.That(resourcesToUpload.Resources[0], Is.EqualTo(TEST_ASSET_HASH)); @@ -58,7 +59,7 @@ public void PublishLevel() message = client.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - GameLevelResponse response = message.Content.ReadAsXML(); + GameLevelResponse response = message.Content.ReadAsXml(); Assert.That(response.Title, Is.EqualTo(level.Title)); Assert.That(response.Description, Is.EqualTo(level.Description)); @@ -70,13 +71,13 @@ public void PublishLevel() Assert.That(message.StatusCode, Is.EqualTo(OK)); //Since theres no new assets, the XML deserializer will deserialize the resources list into null - resourcesToUpload = message.Content.ReadAsXML(); + resourcesToUpload = message.Content.ReadAsXml(); Assert.That(resourcesToUpload.Resources, Is.EqualTo(null)); message = client.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - response = message.Content.ReadAsXML(); + response = message.Content.ReadAsXml(); Assert.That(response.Title, Is.EqualTo(level.Title)); Assert.That(response.Description, Is.EqualTo(level.Description)); } @@ -117,7 +118,7 @@ public void LevelWithLongTitleGetsTruncated() message = client.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - Assert.That(message.Content.ReadAsXML().Title.Length, Is.EqualTo(UgcLimits.TitleLimit)); + Assert.That(message.Content.ReadAsXml().Title.Length, Is.EqualTo(UgcLimits.TitleLimit)); } [Test] @@ -156,7 +157,7 @@ public void LevelWithLongDescriptionGetsTruncated() message = client.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - Assert.That(message.Content.ReadAsXML().Description.Length, Is.EqualTo(UgcLimits.DescriptionLimit)); + Assert.That(message.Content.ReadAsXml().Description.Length, Is.EqualTo(UgcLimits.DescriptionLimit)); } [Test] @@ -397,7 +398,7 @@ public void CantRepublishOtherUsersLevel() HttpResponseMessage message = client1.PostAsync("/lbp/startPublish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedLevelResources resourcesToUpload = message.Content.ReadAsXML(); + SerializedLevelResources resourcesToUpload = message.Content.ReadAsXml(); Assert.That(resourcesToUpload.Resources, Has.Length.EqualTo(1)); Assert.That(resourcesToUpload.Resources[0], Is.EqualTo(TEST_ASSET_HASH)); @@ -409,7 +410,7 @@ public void CantRepublishOtherUsersLevel() message = client1.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - GameLevelResponse response = message.Content.ReadAsXML(); + GameLevelResponse response = message.Content.ReadAsXml(); Assert.That(response.Title, Is.EqualTo(level.Title)); Assert.That(response.Description, Is.EqualTo(level.Description)); @@ -454,7 +455,7 @@ public void CantPublishSameRootLevelHashTwice() HttpResponseMessage message = client1.PostAsync("/lbp/startPublish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedLevelResources resourcesToUpload = message.Content.ReadAsXML(); + SerializedLevelResources resourcesToUpload = message.Content.ReadAsXml(); Assert.That(resourcesToUpload.Resources, Has.Length.EqualTo(1)); Assert.That(resourcesToUpload.Resources[0], Is.EqualTo(TEST_ASSET_HASH)); @@ -466,7 +467,7 @@ public void CantPublishSameRootLevelHashTwice() message = client1.PostAsync("/lbp/publish", new StringContent(level.AsXML())).Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - GameLevelResponse response = message.Content.ReadAsXML(); + GameLevelResponse response = message.Content.ReadAsXml(); Assert.That(response.Title, Is.EqualTo(level.Title)); Assert.That(response.Description, Is.EqualTo(level.Description)); diff --git a/RefreshTests.GameServer/Tests/Levels/ScoreLeaderboardTests.cs b/RefreshTests.GameServer/Tests/Levels/ScoreLeaderboardTests.cs index fc52534f..76c3291c 100644 --- a/RefreshTests.GameServer/Tests/Levels/ScoreLeaderboardTests.cs +++ b/RefreshTests.GameServer/Tests/Levels/ScoreLeaderboardTests.cs @@ -1,4 +1,5 @@ using MongoDB.Bson; +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Lists; @@ -33,7 +34,7 @@ public void SubmitsScore() message = client.GetAsync($"/lbp/topscores/user/{level.LevelId}/1").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedScoreList scores = message.Content.ReadAsXML(); + SerializedScoreList scores = message.Content.ReadAsXml(); Assert.That(scores.Scores, Has.Count.EqualTo(1)); Assert.That(scores.Scores[0].Player, Is.EqualTo(user.Username)); Assert.That(scores.Scores[0].Score, Is.EqualTo(5)); @@ -41,7 +42,7 @@ public void SubmitsScore() message = client.GetAsync($"/lbp/scoreboard/user/{level.LevelId}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMultiLeaderboardResponse scoresMulti = message.Content.ReadAsXML(); + SerializedMultiLeaderboardResponse scoresMulti = message.Content.ReadAsXml(); SerializedPlayerLeaderboardResponse singleplayerScores = scoresMulti.Scoreboards.First(s => s.PlayerCount == 1); Assert.That(singleplayerScores.Scores, Has.Count.EqualTo(1)); Assert.That(singleplayerScores.Scores[0].Player, Is.EqualTo(user.Username)); @@ -71,7 +72,7 @@ public void SubmitsDeveloperScore() message = client.GetAsync($"/lbp/topscores/developer/1/1").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedScoreList scores = message.Content.ReadAsXML(); + SerializedScoreList scores = message.Content.ReadAsXml(); Assert.That(scores.Scores, Has.Count.EqualTo(1)); Assert.That(scores.Scores[0].Player, Is.EqualTo(user.Username)); Assert.That(scores.Scores[0].Score, Is.EqualTo(5)); @@ -79,7 +80,7 @@ public void SubmitsDeveloperScore() message = client.GetAsync($"/lbp/scoreboard/developer/1").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedMultiLeaderboardResponse scoresMulti = message.Content.ReadAsXML(); + SerializedMultiLeaderboardResponse scoresMulti = message.Content.ReadAsXml(); SerializedPlayerLeaderboardResponse singleplayerScores = scoresMulti.Scoreboards.First(s => s.PlayerCount == 1); Assert.That(singleplayerScores.Scores, Has.Count.EqualTo(1)); Assert.That(singleplayerScores.Scores[0].Player, Is.EqualTo(user.Username)); @@ -460,7 +461,7 @@ public async Task GamePaginationSortsCorrectly() context.FillLeaderboard(level, 4, 1); HttpResponseMessage response = await client.GetAsync($"/lbp/topscores/user/{level.LevelId}/1?pageStart=1&pageSize=2"); - SerializedScoreList firstPage = response.Content.ReadAsXML(); + SerializedScoreList firstPage = response.Content.ReadAsXml(); Assert.Multiple(() => { @@ -470,7 +471,7 @@ public async Task GamePaginationSortsCorrectly() }); response = await client.GetAsync($"/lbp/topscores/user/{level.LevelId}/1?pageStart=3&pageSize=2"); - SerializedScoreList secondPage = response.Content.ReadAsXML(); + SerializedScoreList secondPage = response.Content.ReadAsXml(); Assert.Multiple(() => { diff --git a/RefreshTests.GameServer/Tests/Lists/ListTests.cs b/RefreshTests.GameServer/Tests/Lists/ListTests.cs index 1e4951e3..5e0ca92b 100644 --- a/RefreshTests.GameServer/Tests/Lists/ListTests.cs +++ b/RefreshTests.GameServer/Tests/Lists/ListTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Lists; @@ -27,7 +28,7 @@ public async Task LevelListPaginatesCorrectly() while (true) { HttpResponseMessage message = await client.GetAsync($"/lbp/slots/newest?pageStart={(pageSize * page) + 1}&pageSize={pageSize}"); - SerializedMinimalLevelList levelList = message.Content.ReadAsXML(); + SerializedMinimalLevelList levelList = message.Content.ReadAsXml(); if (pageSize * page >= levelList.Total) break; Assert.Multiple(() => diff --git a/RefreshTests.GameServer/Tests/Matching/MatchingTests.cs b/RefreshTests.GameServer/Tests/Matching/MatchingTests.cs index 3211d560..375ac40f 100644 --- a/RefreshTests.GameServer/Tests/Matching/MatchingTests.cs +++ b/RefreshTests.GameServer/Tests/Matching/MatchingTests.cs @@ -17,8 +17,7 @@ public class MatchingTests : GameServerTest public void CreatesRooms() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -34,24 +33,8 @@ public void CreatesRooms() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - GuidChecker = null!, - Token = token1, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); Assert.Multiple(() => { @@ -76,8 +59,7 @@ public void CreatesRooms() public void DoesntMatchIfNoRooms() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -90,15 +72,7 @@ public void DoesntMatchIfNoRooms() // Setup room GameUser user1 = context.CreateUser(); Token token1 = context.CreateToken(user1); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - GuidChecker = null!, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); // Tell user1 to try to find a room Response response = match.ExecuteMethod("FindBestRoom", new SerializedRoomData @@ -107,15 +81,7 @@ public void DoesntMatchIfNoRooms() { NatType.Open, }, - }, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - GuidChecker = null!, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - }, context.Server.Value.GameServerConfig); + }, context.GetDataContext(token1), context.Server.Value.GameServerConfig); // Deserialize the result List responseObjects = @@ -131,8 +97,7 @@ public void DoesntMatchIfNoRooms() public void StrictNatCantJoinStrict() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -150,34 +115,11 @@ public void StrictNatCantJoinStrict() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - GuidChecker = null!, - Token = token1, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); // Tell user2 to try to find a room - Response response = match.ExecuteMethod("FindBestRoom", new SerializedRoomData { NatType = [NatType.Strict], }, new DataContext { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - GuidChecker = null!, - Token = token2, - }, context.Server.Value.GameServerConfig); + Response response = match.ExecuteMethod("FindBestRoom", new SerializedRoomData { NatType = [NatType.Strict], }, context.GetDataContext(token2), context.Server.Value.GameServerConfig); //Deserialize the result List responseObjects = @@ -193,8 +135,7 @@ public void StrictNatCantJoinStrict() public void StrictNatCanJoinOpen() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -221,24 +162,8 @@ public void StrictNatCanJoinOpen() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData2, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData2, context.GetDataContext(token2), context.Server.Value.GameServerConfig); // Tell user2 to try to find a room Response response = match.ExecuteMethod("FindBestRoom", new SerializedRoomData @@ -246,15 +171,7 @@ public void StrictNatCanJoinOpen() NatType = new List { NatType.Strict, }, - }, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + }, context.GetDataContext(token2), context.Server.Value.GameServerConfig); Assert.That(response.StatusCode, Is.EqualTo(OK)); } @@ -262,8 +179,7 @@ public void StrictNatCanJoinOpen() public void MatchesPlayersTogether() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -281,24 +197,8 @@ public void MatchesPlayersTogether() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); // Tell user2 to try to find a room Response response = match.ExecuteMethod("FindBestRoom", new SerializedRoomData @@ -306,15 +206,7 @@ public void MatchesPlayersTogether() NatType = new List { NatType.Open, }, - }, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + }, context.GetDataContext(token2), context.Server.Value.GameServerConfig); Assert.That(response.StatusCode, Is.EqualTo(OK)); } @@ -322,8 +214,7 @@ public void MatchesPlayersTogether() public void HostCanSetPlayersInRoom() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -341,24 +232,8 @@ public void HostCanSetPlayersInRoom() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); // Get user1 and user2 in the same room roomData.Players = new List @@ -367,15 +242,7 @@ public void HostCanSetPlayersInRoom() user2.Username, }; - match.ExecuteMethod("UpdatePlayersInRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("UpdatePlayersInRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); GameRoom? room = match.RoomAccessor.GetRoomByUser(user1); Assert.Multiple(() => { @@ -390,8 +257,7 @@ public void HostCanSetPlayersInRoom() public void PlayersCanLeaveAndSplitIntoNewRoom() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { @@ -409,24 +275,8 @@ public void PlayersCanLeaveAndSplitIntoNewRoom() Token token1 = context.CreateToken(user1); Token token2 = context.CreateToken(user2); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); // Get user1 and user2 in the same room roomData.Players = new List @@ -436,30 +286,14 @@ public void PlayersCanLeaveAndSplitIntoNewRoom() }; { - match.ExecuteMethod("UpdatePlayersInRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token1, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("UpdatePlayersInRoom", roomData, context.GetDataContext(token1), context.Server.Value.GameServerConfig); GameRoom? user1Room = match.RoomAccessor.GetRoomByUser(user1); Assert.That(user1Room, Is.Not.Null); Assert.That(user1Room!.PlayerIds.FirstOrDefault(r => r.Id == user2.UserId), Is.Not.Null); } { - match.ExecuteMethod("CreateRoom", roomData, new DataContext - { - Database = context.Database, - Logger = context.Server.Value.Logger, - DataStore = null!, //this isn't accessed by matching - Match = match, - Token = token2, - GuidChecker = null!, - }, context.Server.Value.GameServerConfig); + match.ExecuteMethod("CreateRoom", roomData, context.GetDataContext(token2), context.Server.Value.GameServerConfig); GameRoom? user1Room = match.RoomAccessor.GetRoomByUser(user1); GameRoom? user2Room = match.RoomAccessor.GetRoomByUser(user2); Assert.That(user1Room, Is.Not.Null); @@ -474,8 +308,7 @@ public void PlayersCanLeaveAndSplitIntoNewRoom() public void DoesntMatchIfLookingForLevelWhenPod() { using TestContext context = this.GetServer(false); - MatchService match = new(Logger); - match.Initialize(); + MatchService match = context.GetService(); SerializedRoomData roomData = new() { diff --git a/RefreshTests.GameServer/Tests/Photos/PhotoEndpointsTests.cs b/RefreshTests.GameServer/Tests/Photos/PhotoEndpointsTests.cs index 0ffbc247..9f3e2dc7 100644 --- a/RefreshTests.GameServer/Tests/Photos/PhotoEndpointsTests.cs +++ b/RefreshTests.GameServer/Tests/Photos/PhotoEndpointsTests.cs @@ -1,4 +1,5 @@ using System.Reflection; +using Refresh.Common.Extensions; using Refresh.Common.Helpers; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; @@ -60,7 +61,7 @@ public void UploadAndDeletePhoto() message = client.GetAsync($"/lbp/photos/by?user={user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedPhotoList response = message.Content.ReadAsXML(); + SerializedPhotoList response = message.Content.ReadAsXml(); Assert.That(response.Items, Has.Count.EqualTo(1)); Assert.That(response.Items[0].LargeHash, Is.EqualTo(TEST_ASSET_HASH)); @@ -68,7 +69,7 @@ public void UploadAndDeletePhoto() message = client.GetAsync($"/lbp/photos/with?user={user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - response = message.Content.ReadAsXML(); + response = message.Content.ReadAsXml(); Assert.That(response.Items, Has.Count.EqualTo(1)); Assert.That(response.Items[0].LargeHash, Is.EqualTo(TEST_ASSET_HASH)); @@ -80,7 +81,7 @@ public void UploadAndDeletePhoto() message = client.GetAsync($"/lbp/photos/by?user={user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - response = message.Content.ReadAsXML(); + response = message.Content.ReadAsXml(); Assert.That(response.Items, Has.Count.EqualTo(0)); } @@ -280,7 +281,7 @@ public void CantDeleteOthersPhoto() message = client1.GetAsync($"/lbp/photos/by?user={user1.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - SerializedPhotoList response = message.Content.ReadAsXML(); + SerializedPhotoList response = message.Content.ReadAsXml(); Assert.That(response.Items, Has.Count.EqualTo(1)); Assert.That(response.Items[0].LargeHash, Is.EqualTo(TEST_ASSET_HASH)); @@ -292,7 +293,7 @@ public void CantDeleteOthersPhoto() message = client1.GetAsync($"/lbp/photos/by?user={user1.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - response = message.Content.ReadAsXML(); + response = message.Content.ReadAsXml(); Assert.That(response.Items, Has.Count.EqualTo(1)); } } \ No newline at end of file diff --git a/RefreshTests.GameServer/Tests/Relations/FavouriteSlotTests.cs b/RefreshTests.GameServer/Tests/Relations/FavouriteSlotTests.cs index 29939e9e..61560aa7 100644 --- a/RefreshTests.GameServer/Tests/Relations/FavouriteSlotTests.cs +++ b/RefreshTests.GameServer/Tests/Relations/FavouriteSlotTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Lists; @@ -25,7 +26,7 @@ public void FavouriteAndUnfavouriteLevel() message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure the only entry is the level - SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXML(); + SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); @@ -37,7 +38,7 @@ public void FavouriteAndUnfavouriteLevel() message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its now empty - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -60,7 +61,7 @@ public void CantFavouriteMissingLevel(bool psp) message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its now empty - SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXML(); + SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -104,7 +105,7 @@ public void CantFavouriteLevelTwice(bool psp) message = client.GetAsync($"/lbp/favouriteSlots/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure it has the level - SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXML(); + SerializedMinimalFavouriteLevelList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().LevelId, Is.EqualTo(level.LevelId)); } diff --git a/RefreshTests.GameServer/Tests/Relations/FavouriteUserTests.cs b/RefreshTests.GameServer/Tests/Relations/FavouriteUserTests.cs index dcb459e0..6c7455f4 100644 --- a/RefreshTests.GameServer/Tests/Relations/FavouriteUserTests.cs +++ b/RefreshTests.GameServer/Tests/Relations/FavouriteUserTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Lists; @@ -25,7 +26,7 @@ public void FavouriteAndUnfavouriteUser() message = client.GetAsync($"/lbp/favouriteUsers/{user1.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure the only entry is the user we favourited - SerializedFavouriteUserList result = message.Content.ReadAsXML(); + SerializedFavouriteUserList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().Handle.Username, Is.EqualTo(user2.Username)); @@ -37,7 +38,7 @@ public void FavouriteAndUnfavouriteUser() message = client.GetAsync($"/lbp/favouriteUsers/{user1.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure its now empty - result = message.Content.ReadAsXML(); + result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -60,7 +61,7 @@ public void CantFavouriteMissingUser(bool psp) message = client.GetAsync($"/lbp/favouriteUsers/{user.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure it is empty - SerializedFavouriteUserList result = message.Content.ReadAsXML(); + SerializedFavouriteUserList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(0)); } @@ -105,7 +106,7 @@ public void CantFavouriteUserTwice(bool psp) message = client.GetAsync($"/lbp/favouriteUsers/{user1.Username}").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); //Make sure it has the user - SerializedFavouriteUserList result = message.Content.ReadAsXML(); + SerializedFavouriteUserList result = message.Content.ReadAsXml(); Assert.That(result.Items, Has.Count.EqualTo(1)); Assert.That(result.Items.First().Handle.Username, Is.EqualTo(user2.Username)); } diff --git a/RefreshTests.GameServer/Tests/Relations/ReviewTests.cs b/RefreshTests.GameServer/Tests/Relations/ReviewTests.cs index c24903eb..4e29389c 100644 --- a/RefreshTests.GameServer/Tests/Relations/ReviewTests.cs +++ b/RefreshTests.GameServer/Tests/Relations/ReviewTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.Reviews; @@ -37,7 +38,7 @@ public void TestPostReview(string slotType) HttpResponseMessage response = reviewerClient.GetAsync($"/lbp/reviewsFor/{slotType}/{levelId}").Result; Assert.That(response.StatusCode, Is.EqualTo(OK)); - SerializedGameReviewResponse levelReviews = response.Content.ReadAsXML(); + SerializedGameReviewResponse levelReviews = response.Content.ReadAsXml(); Assert.That(levelReviews.Items, Has.Count.EqualTo(1)); Assert.That(levelReviews.Items[0].Text, Is.EqualTo(review.Text)); Assert.That(levelReviews.Items[0].Labels, Is.EqualTo(review.Labels)); @@ -45,7 +46,7 @@ public void TestPostReview(string slotType) response = reviewerClient.GetAsync($"/lbp/reviewsBy/{reviewPublisher.Username}").Result; Assert.That(response.StatusCode, Is.EqualTo(OK)); - levelReviews = response.Content.ReadAsXML(); + levelReviews = response.Content.ReadAsXml(); Assert.That(levelReviews.Items, Has.Count.EqualTo(1)); Assert.That(levelReviews.Items[0].Text, Is.EqualTo(review.Text)); Assert.That(levelReviews.Items[0].Labels, Is.EqualTo(review.Labels)); diff --git a/RefreshTests.GameServer/Tests/Users/ActivityEndpointsTests.cs b/RefreshTests.GameServer/Tests/Users/ActivityEndpointsTests.cs index c9b678c8..bbf388db 100644 --- a/RefreshTests.GameServer/Tests/Users/ActivityEndpointsTests.cs +++ b/RefreshTests.GameServer/Tests/Users/ActivityEndpointsTests.cs @@ -1,3 +1,4 @@ +using Refresh.Common.Extensions; using Refresh.GameServer.Authentication; using Refresh.GameServer.Types.Levels; using Refresh.GameServer.Types.News; @@ -23,7 +24,7 @@ public void GetNews() HttpResponseMessage message = client.GetAsync("/lbp/news").Result; Assert.That(message.StatusCode, Is.EqualTo(OK)); - GameNewsResponse response = message.Content.ReadAsXML(); + GameNewsResponse response = message.Content.ReadAsXml(); Assert.That(response.Subcategory.Items, Has.Count.EqualTo(1)); Assert.That(response.Subcategory.Items[0].Subject, Is.EqualTo("Team Pick")); }