Skip to content

Commit

Permalink
Merge pull request #2034 from erri120/fix/2031
Browse files Browse the repository at this point in the history
Fix exception on cached value
  • Loading branch information
Al12rs authored Sep 16, 2024
2 parents 9169f04 + 681377c commit 35d1b1c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 48 deletions.
53 changes: 26 additions & 27 deletions src/Networking/NexusMods.Networking.NexusWebApi/Auth/JWTToken.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
using System.Diagnostics.CodeAnalysis;
using DynamicData.Kernel;
using NexusMods.Abstractions.NexusWebApi.DTOs.OAuth;
using NexusMods.MnemonicDB.Abstractions;
using NexusMods.MnemonicDB.Abstractions.Attributes;
using NexusMods.MnemonicDB.Abstractions.Models;
using NexusMods.MnemonicDB.Abstractions.TxFunctions;
using File = NexusMods.Abstractions.Loadouts.Files.File;

namespace NexusMods.Networking.NexusWebApi.Auth;

Expand All @@ -29,49 +27,50 @@ public partial class JWTToken : IModelDefinition
/// The date at which the token expires
/// </summary>
public static readonly TimestampAttribute ExpiresAt = new(Namespace, nameof(ExpiresAt));



private static Optional<EntityId> GetEntityId(IDb db)
{
var datoms = db.Datoms(PrimaryAttribute);
return datoms.Count == 0 ? Optional<EntityId>.None : datoms[0].E;
}

/// <summary>
/// Try to find the JWT Token in the database.
/// </summary>
public static bool TryFind(IDb db, out ReadOnly token)
{
var found = All(db).FirstOrDefault();
if (found.IsValid())
var entityId = GetEntityId(db);
if (!entityId.HasValue)
{
token = found;
return true;
token = default(ReadOnly);
return false;
}
token = default(ReadOnly);
return false;

token = Load(db, entityId.Value);
return token.IsValid();
}



/// <summary>
/// Creates a new JWT Token model from a <see cref="JwtTokenReply"/>. And reuses the existing
/// database id if it exists, as this data is a singleton.
/// </summary>
public static EntityId? Create(IDb db, ITransaction tx, JwtTokenReply reply)
public static Optional<EntityId> Create(IDb db, ITransaction tx, JwtTokenReply reply)
{
if (reply.AccessToken is null || reply.RefreshToken is null)
return null;

var existingId = db.Datoms(JWTToken.AccessToken).FirstOrDefault().E;
if (existingId == EntityId.From(0))
existingId = tx.TempId();

tx.Add(existingId, JWTToken.AccessToken, reply.AccessToken);
tx.Add(existingId, JWTToken.RefreshToken, reply.RefreshToken);
tx.Add(existingId, JWTToken.ExpiresAt, DateTimeOffset.FromUnixTimeSeconds(reply.CreatedAt).DateTime + TimeSpan.FromSeconds(reply.ExpiresIn));
if (reply.AccessToken is null || reply.RefreshToken is null) return Optional<EntityId>.None;

var existingId = GetEntityId(db);
var entityId = existingId.HasValue ? existingId.Value : tx.TempId();

tx.Add(entityId, AccessToken, reply.AccessToken);
tx.Add(entityId, RefreshToken, reply.RefreshToken);
tx.Add(entityId, ExpiresAt, DateTimeOffset.FromUnixTimeSeconds(reply.CreatedAt).DateTime + TimeSpan.FromSeconds(reply.ExpiresIn));

return existingId;
return entityId;
}

/// <summary>
/// Model for the JWT Token
/// </summary>
/// <param name="tx"></param>
public partial struct ReadOnly
{
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
using System.Reactive.Linq;
using DynamicData.Binding;
using Microsoft.Extensions.Logging;
using NexusMods.Abstractions.MnemonicDB.Attributes;
using NexusMods.Abstractions.NexusWebApi;
using NexusMods.Abstractions.NexusWebApi.DTOs.OAuth;
using NexusMods.Abstractions.NexusWebApi.Types;
using NexusMods.Abstractions.Serialization;
using NexusMods.Extensions.BCL;
using NexusMods.MnemonicDB.Abstractions;
using NexusMods.MnemonicDB.Abstractions.Query;

namespace NexusMods.Networking.NexusWebApi.Auth;

Expand All @@ -31,40 +25,32 @@ public OAuth2MessageFactory(
_conn = conn;
_auth = auth;
_logger = logger;

_conn.ObserveDatoms(SliceDescriptor.Create(JWTToken.AccessToken, _conn.Registry))
.Subscribe(_ => _cachedTokenEntity = null);
}

private JWTToken.ReadOnly? _cachedTokenEntity;
private readonly IConnection _conn;

private async ValueTask<string?> GetOrRefreshToken(CancellationToken cancellationToken)
{
if (!JWTToken.TryFind(_conn.Db, out var token))
return null;

_cachedTokenEntity = token;
if (!token.HasExpired)
return _cachedTokenEntity!.Value.AccessToken;
if (!JWTToken.TryFind(_conn.Db, out var token)) return null;
if (!token.HasExpired) return token.AccessToken;

_logger.LogDebug("Refreshing expired OAuth token");

var newToken = await _auth.RefreshToken(token.RefreshToken, cancellationToken);
var db = _conn.Db;
using var tx = _conn.BeginTransaction();

var newTokenEntity = JWTToken.Create(db, tx, newToken!);
if (newTokenEntity is null)
if (!newTokenEntity.HasValue)
{
_logger.LogError("Invalid new token in OAuth2MessageFactory");
return null;
}

var result = await tx.Commit();

_cachedTokenEntity = JWTToken.Load(result.Db, result[newTokenEntity.Value]);
return _cachedTokenEntity!.Value.AccessToken;
token = JWTToken.Load(result.Db, result[newTokenEntity.Value]);
return token.AccessToken;
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public async Task LoginAsync(CancellationToken token = default)
using var tx = _conn.BeginTransaction();

var newTokenEntity = JWTToken.Create(_conn.Db, tx, jwtToken);
if (newTokenEntity is null)
if (!newTokenEntity.HasValue)
{
_logger.LogError("Invalid new token data");
return;
Expand Down

0 comments on commit 35d1b1c

Please sign in to comment.