Skip to content

Commit

Permalink
Merge pull request #666 from erri120/issue-530-oauth
Browse files Browse the repository at this point in the history
OAuth Rework
  • Loading branch information
Al12rs authored Sep 27, 2023
2 parents 9b34405 + 36b2ecd commit fd39fed
Show file tree
Hide file tree
Showing 27 changed files with 496 additions and 325 deletions.
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
<!-- System -->
<PackageVersion Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
<PackageVersion Include="System.CommandLine.NamingConventionBinder" Version="2.0.0-beta4.22272.1" />
<PackageVersion Include="System.IdentityModel.Tokens.Jwt" Version="7.0.0" />
<PackageVersion Include="System.IdentityModel.Tokens.Jwt" Version="7.0.0-preview3" />
<PackageVersion Include="System.IO.Hashing" Version="8.0.0-rc.1.23419.4" />
<PackageVersion Include="System.Linq.Async" Version="6.0.1" />
<PackageVersion Include="System.Reactive" Version="6.0.0" />
Expand Down
1 change: 1 addition & 0 deletions NexusMods.App.sln.DotSettings
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=DI/@EntryIndexedValue">DI</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=EA/@EntryIndexedValue">EA</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=GOG/@EntryIndexedValue">GOG</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=JWT/@EntryIndexedValue">JWT</s:String>
<s:String x:Key="/Default/CodeStyle/Naming/CSharpNaming/Abbreviations/=VM/@EntryIndexedValue">VM</s:String>
<s:Boolean x:Key="/Default/UserDictionary/Words/=ative/@EntryIndexedValue">True</s:Boolean>
<s:Boolean x:Key="/Default/UserDictionary/Words/=Avalonia/@EntryIndexedValue">True</s:Boolean>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class ApiKeyMessageFactory : IAuthenticatingMessageFactory
/// The name of the environment variable that contains the API key.
/// </summary>
public const string NexusApiKeyEnvironmentVariable = "NEXUS_API_KEY";

private static readonly IId ApiKeyId = new IdVariableLength(EntityCategory.AuthData, "NexusMods.Networking.NexusWebApi.ApiKey"u8.ToArray());

private readonly IDataStore _store;
Expand Down Expand Up @@ -75,7 +75,7 @@ public ValueTask SetApiKey(string apiKey)
{
Name = result.Data.Name,
IsPremium = result.Data.IsPremium,
IsSupporter = result.Data.IsSupporter,
AvatarUrl = result.Data.ProfileUrl
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,9 @@
namespace NexusMods.Networking.NexusWebApi.NMA;

/// <summary>
/// entity to store a JWT token
/// TODO: Right now we follow a "ask for forgiveness, not permission" approach to using the token,
/// so we use the access token until we get an error indicating it has expired, then refresh the
/// token and retry. This way we don't need to store when the token expires even though we have
/// that information. If we wanted to save one request every six hours or if the lifetime of access tokens
/// changes, we might want to refresh tokens more proactively and then we'd need to save the expire time.
/// Represents a JWT Token in our DataStore.
/// </summary>
[JsonName("JWTTokens")]
// ReSharper disable once InconsistentNaming
public record JWTTokenEntity : Entity
{
/// <summary>
Expand All @@ -22,17 +16,33 @@ public record JWTTokenEntity : Entity
public static readonly IId StoreId = new IdVariableLength(EntityCategory.AuthData, "NexusMods.Networking.NexusWebApi.JWTTokens"u8.ToArray());

/// <inheritdoc/>
public override EntityCategory Category => EntityCategory.AuthData;
public override EntityCategory Category => StoreId.Category;

/// <summary>
/// the current access token
/// Gets the access token.
/// </summary>
/// <remarks>
/// This token expires at <see cref="ExpiresAt"/> and needs to be refreshed using <see cref="RefreshToken"/>.
/// </remarks>
public required string AccessToken { get; init; }

/// <summary>
/// token needed to generate a new access token when the current one has expired.
/// Gets the refresh token.
/// </summary>
public required string RefreshToken { get; init; }

/// <summary>
/// Gets the date at which the <see cref="AccessToken"/> expires.
/// </summary>
public required DateTimeOffset ExpiresAt { get; init; }

/// <summary>
/// Checks whether the token has expired.
/// </summary>
public bool HasExpired()
{
return ExpiresAt - TimeSpan.FromMinutes(5) <= DateTimeOffset.UtcNow;
}
}


60 changes: 48 additions & 12 deletions src/Networking/NexusMods.Networking.NexusWebApi.NMA/LoginManager.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System.Reactive.Concurrency;
using System.Reactive.Linq;
using JetBrains.Annotations;
using NexusMods.Common;
using NexusMods.Common.ProtocolRegistration;
using NexusMods.DataModel.Abstractions;
using NexusMods.Networking.NexusWebApi.Types;
Expand All @@ -8,7 +11,8 @@ namespace NexusMods.Networking.NexusWebApi.NMA;
/// <summary>
/// Component for handling login and logout from the Nexus Mods
/// </summary>
public class LoginManager
[PublicAPI]
public sealed class LoginManager : IDisposable
{
private readonly OAuth _oauth;
private readonly IDataStore _dataStore;
Expand All @@ -24,7 +28,7 @@ public class LoginManager
/// <summary>
/// True if the user is logged in
/// </summary>
public IObservable<bool> IsLoggedIn => UserInfo.Select(info => info != null);
public IObservable<bool> IsLoggedIn => UserInfo.Select(info => info is not null);

/// <summary>
/// True if the user is logged in and is a premium member
Expand All @@ -34,7 +38,7 @@ public class LoginManager
/// <summary>
/// The user's avatar
/// </summary>
public IObservable<Uri?> Avatar => UserInfo.Select(info => info?.Avatar);
public IObservable<Uri?> Avatar => UserInfo.Select(info => info?.AvatarUrl);

/// <summary/>
/// <param name="client">Nexus API client.</param>
Expand All @@ -44,25 +48,45 @@ public class LoginManager
/// <param name="protocolRegistration">Used to register NXM protocol.</param>
public LoginManager(Client client,
IAuthenticatingMessageFactory msgFactory,
OAuth oauth, IDataStore dataStore, IProtocolRegistration protocolRegistration)
OAuth oauth,
IDataStore dataStore,
IProtocolRegistration protocolRegistration)
{
_oauth = oauth;
_msgFactory = msgFactory;
_client = client;
_dataStore = dataStore;
_protocolRegistration = protocolRegistration;

UserInfo = _dataStore.IdChanges
// NOTE(err120): Since IDs don't change on startup, we can insert
// a fake change at the start of the observable chain. This will only
// run once at startup and notify the subscribers.
.Merge(Observable.Return(JWTTokenEntity.StoreId))
.Where(id => id.Equals(JWTTokenEntity.StoreId))
.Select(_ => true)
.StartWith(true)
.SelectMany(async _ => await Verify());
.ObserveOn(TaskPoolScheduler.Default)
.SelectMany(async _ => await Verify(CancellationToken.None));
}

private async Task<UserInfo?> Verify()
private CachedObject<UserInfo> _cachedUserInfo = new(TimeSpan.FromHours(1));
private readonly SemaphoreSlim _verifySemaphore = new(initialCount: 1, maxCount: 1);

private async Task<UserInfo?> Verify(CancellationToken cancellationToken)
{
if (await _msgFactory.IsAuthenticated())
return await _msgFactory.Verify(_client, CancellationToken.None);
return null;
var cachedValue = _cachedUserInfo.Get();
if (cachedValue is not null) return cachedValue;

using var waiter = _verifySemaphore.CustomWait(cancellationToken);
cachedValue = _cachedUserInfo.Get();
if (cachedValue is not null) return cachedValue;

var isAuthenticated = await _msgFactory.IsAuthenticated();
if (!isAuthenticated) return null;

var userInfo = await _msgFactory.Verify(_client, cancellationToken);
_cachedUserInfo.Store(userInfo);

return userInfo;
}

/// <summary>
Expand All @@ -75,10 +99,15 @@ public async Task LoginAsync(CancellationToken token = default)
await _protocolRegistration.RegisterSelf("nxm");

var jwtToken = await _oauth.AuthorizeRequest(token);
var createdAt = DateTimeOffset.FromUnixTimeSeconds(jwtToken.CreatedAt);
var expiresIn = TimeSpan.FromSeconds(jwtToken.ExpiresIn);
var expiresAt = createdAt + expiresIn;

_dataStore.Put(JWTTokenEntity.StoreId, new JWTTokenEntity
{
RefreshToken = jwtToken.RefreshToken,
AccessToken = jwtToken.AccessToken
AccessToken = jwtToken.AccessToken,
ExpiresAt = expiresAt
});
}

Expand All @@ -88,6 +117,13 @@ public async Task LoginAsync(CancellationToken token = default)
public Task Logout()
{
_dataStore.Delete(JWTTokenEntity.StoreId);
_cachedUserInfo.Evict();
return Task.CompletedTask;
}

/// <inheritdoc/>
public void Dispose()
{
_verifySemaphore.Dispose();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\NexusMods.DataModel\NexusMods.DataModel.csproj" />
<ProjectReference Include="..\NexusMods.Networking.NexusWebApi\NexusMods.Networking.NexusWebApi.csproj" />
<ProjectReference Include="..\..\NexusMods.DataModel\NexusMods.DataModel.csproj" />
<ProjectReference Include="..\NexusMods.Networking.NexusWebApi\NexusMods.Networking.NexusWebApi.csproj" />
</ItemGroup>

<ItemGroup>
<InternalsVisibleTo Include="NexusMods.Networking.NexusWebApi.Tests" />
</ItemGroup>

</Project>
90 changes: 49 additions & 41 deletions src/Networking/NexusMods.Networking.NexusWebApi.NMA/OAuth.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Reactive.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Text.Encodings.Web;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using NexusMods.Common;
Expand All @@ -20,10 +21,10 @@ namespace NexusMods.Networking.NexusWebApi.NMA;
public class OAuth
{
private const string OAuthUrl = "https://users.nexusmods.com/oauth";
// the redirect url has to explicitly be permitted by the server so we can't change
// this without consulting the backend team
// NOTE(erri120): The backend has a list of valid redirect URLs and client IDs.
// We can't change these on our own.
private const string OAuthRedirectUrl = "nxm://oauth/callback";
private const string OAuthClientId = "vortex";
private const string OAuthClientId = "nma";

private readonly ILogger<OAuth> _logger;
private readonly HttpClient _http;
Expand All @@ -35,8 +36,11 @@ public class OAuth
/// <summary>
/// constructor
/// </summary>
public OAuth(ILogger<OAuth> logger, HttpClient http, IIDGenerator idGen,
IOSInterop os, IMessageConsumer<NXMUrlMessage> nxmUrlMessages,
public OAuth(ILogger<OAuth> logger,
HttpClient http,
IIDGenerator idGen,
IOSInterop os,
IMessageConsumer<NXMUrlMessage> nxmUrlMessages,
IInterprocessJobManager jobManager)
{
_logger = logger;
Expand All @@ -48,39 +52,38 @@ public OAuth(ILogger<OAuth> logger, HttpClient http, IIDGenerator idGen,
}

/// <summary>
/// make an authorization request
/// Make an authorization request
/// </summary>
/// <param name="cancel"></param>
/// <param name="cancellationToken"></param>
/// <returns>task with the jwt token once we receive one</returns>
public async Task<JwtTokenReply> AuthorizeRequest(CancellationToken cancel)
public async Task<JwtTokenReply> AuthorizeRequest(CancellationToken cancellationToken)
{
_logger.LogInformation("Starting NexusMods OAuth2 authorization request");
var state = _idGen.UUIDv4();

// see https://www.rfc-editor.org/rfc/rfc7636#section-4.1
var verifier = _idGen.UUIDv4().Replace("-", "").ToBase64();
var codeVerifier = _idGen.UUIDv4().ToBase64();

// see https://www.rfc-editor.org/rfc/rfc7636#section-4.2
using var sha256 = SHA256.Create();
var challenge = sha256.ComputeHash(Encoding.UTF8.GetBytes(verifier)).ToBase64();
var codeChallengeBytes = SHA256.HashData(Encoding.UTF8.GetBytes(codeVerifier));
var codeChallenge = StringEncodingExtension.Base64UrlEncode(codeChallengeBytes);

var state = _idGen.UUIDv4();

// Start listening first, otherwise we might miss the message
var codeTask = _nxmUrlMessages.Messages
.Where(url => url.Value.UrlType == NXMUrlType.OAuth)
.Where(url => url.Value.OAuth.State == state)
.Select(url => url.Value.OAuth.Code!)
.Where(url => url.Value.UrlType == NXMUrlType.OAuth && url.Value.OAuth.State == state)
.Select(url => url.Value.OAuth.Code)
.Where(code => code is not null)
.Select(code => code!)
.ToAsyncEnumerable()
.FirstAsync(cancel);
.FirstAsync(cancellationToken);

_logger.LogInformation("Opening browser for NexusMods OAuth2 authorization request");
var url = GenerateAuthorizeUrl(challenge, state);
var url = GenerateAuthorizeUrl(codeChallenge, state);
using var job = CreateJob(url);

// see https://www.rfc-editor.org/rfc/rfc7636#section-4.3
await _os.OpenUrl(url, cancel);
await _os.OpenUrl(url, cancellationToken);
var code = await codeTask;

_logger.LogInformation("Received OAuth2 authorization code, requesting token");
return await AuthorizeToken(verifier, code, cancel);

return await AuthorizeToken(codeVerifier, code, cancellationToken);
}

private IInterprocessJob CreateJob(Uri url)
Expand Down Expand Up @@ -127,37 +130,42 @@ private async Task<JwtTokenReply> AuthorizeToken(string verifier, string code, C
return JsonSerializer.Deserialize<JwtTokenReply>(responseString);
}

private string SanitizeBase64(string input)
{
return input
.Replace("+", "-")
.Replace("/", "_")
.TrimEnd('=');
}

private Uri GenerateAuthorizeUrl(string challenge, string state)
internal static Uri GenerateAuthorizeUrl(string challenge, string state)
{
// TODO: switch to Microsoft.AspNetCore.WebUtilities when .NET 8 is available
var request = new Dictionary<string, string>
{
{ "response_type", "code" },
{ "scope", "public" },
{ "scope", "openid profile email" },
{ "code_challenge_method", "S256" },
{ "client_id", OAuthClientId },
{ "redirect_uri", OAuthRedirectUrl },
{ "code_challenge", SanitizeBase64(challenge) },
{ "code_challenge", challenge },
{ "state", state },
};
return new Uri($"{OAuthUrl}/authorize?{StringifyRequest(request)}");

return new Uri($"{OAuthUrl}/authorize{CreateQueryString(request)}");
}

private string StringifyRequest(IDictionary<string, string> input)
private static string CreateQueryString(Dictionary<string, string> parameters)
{
IList<string> properties = new List<string>();
foreach (var kv in input)
var builder = new StringBuilder();
var first = true;
foreach (var pair in parameters)
{
properties.Add($"{kv.Key}={Uri.EscapeDataString(kv.Value)}");
var (key, value) = pair;

builder.Append(first ? '?' : '&');
builder.Append(UrlEncoder.Default.Encode(key));
builder.Append('=');
if (!string.IsNullOrEmpty(value))
{
builder.Append(UrlEncoder.Default.Encode(value));
}

first = false;
}

return string.Join("&", properties);
return builder.ToString();
}
}
Loading

0 comments on commit fd39fed

Please sign in to comment.