Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor auth flow session handling #129

Merged
merged 11 commits into from
Jul 23, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ public interface IAuthFlowSessionStorage
/// Deletes the authorization session record by the session identifier.
/// </summary>
/// <param name="context">Agent Context</param>
/// <param name="sessionId">Session Identifier of a Authorization Code Flow session</param>
/// <param name="vciSessionState">Session State Identifier of a Authorization Code Flow session</param>
/// <returns></returns>
Task<bool> DeleteAsync(IAgentContext context, VciSessionId sessionId);
Task<bool> DeleteAsync(IAgentContext context, VciSessionState vciSessionState);

/// <summary>
/// Retrieves the authorization session record by the session identifier.
/// </summary>
/// <param name="context">Agent Context</param>
/// <param name="sessionId">Session Identifier of a Authorization Code Flow session</param>
/// <param name="vciSessionState">Session State Identifier of a Authorization Code Flow session</param>
/// <returns></returns>
Task<AuthFlowSessionRecord> GetAsync(IAgentContext context, VciSessionId sessionId);
Task<AuthFlowSessionRecord> GetAsync(IAgentContext context, VciSessionState vciSessionState);

/// <summary>
/// Stores the authorization session record.
Expand All @@ -35,11 +35,11 @@ public interface IAuthFlowSessionStorage
/// Parameters required for the authorization during the VCI authorization code
/// flow.
/// </param>
/// <param name="sessionId"></param>
/// <param name="vciSessionState">Session State Identifier of a Authorization Code Flow session</param>
/// <returns></returns>
Task<string> StoreAsync(
IAgentContext agentContext,
AuthorizationData authorizationData,
AuthorizationCodeParameters authorizationCodeParameters,
VciSessionId sessionId);
VciSessionState vciSessionState);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using WalletFramework.Core.Functional;

namespace WalletFramework.Oid4Vc.Oid4Vci.AuthFlow.Errors;

public record VciSessionStateError(string Value) : Error($"Invalid VciSessionState: {Value}");
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,26 @@ public async Task<string> StoreAsync(
IAgentContext agentContext,
AuthorizationData authorizationData,
AuthorizationCodeParameters authorizationCodeParameters,
VciSessionId sessionId)
VciSessionState vciSessionState)
{
var record = new AuthFlowSessionRecord(
authorizationData,
authorizationCodeParameters,
sessionId);
vciSessionState);

await _recordService.AddAsync(agentContext.Wallet, record, AuthFlowSessionRecordFun.EncodeToJson);

return record.Id;
}

/// <inheritdoc />
public async Task<AuthFlowSessionRecord> GetAsync(IAgentContext context, VciSessionId sessionId)
public async Task<AuthFlowSessionRecord> GetAsync(IAgentContext context, VciSessionState vciSessionState)
{
var record = await _recordService.GetAsync(context.Wallet, sessionId, AuthFlowSessionRecordFun.DecodeFromJson);
var record = await _recordService.GetAsync(context.Wallet, vciSessionState, AuthFlowSessionRecordFun.DecodeFromJson);
return record!;
}

/// <inheritdoc />
public async Task<bool> DeleteAsync(IAgentContext context, VciSessionId sessionId) =>
await _recordService.DeleteAsync<AuthFlowSessionRecord>(context.Wallet, sessionId);
public async Task<bool> DeleteAsync(IAgentContext context, VciSessionState vciSessionState) =>
await _recordService.DeleteAsync<AuthFlowSessionRecord>(context.Wallet, vciSessionState);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ public record IssuanceSession
/// <summary>
/// Gets the session identifier.
/// </summary>
public VciSessionId SessionId { get; }
public VciSessionState VciSessionState { get; }

/// <summary>
/// Gets the actual authorization code that is received from the authorization server upon successful authorization.
/// </summary>
public string Code { get; }

private IssuanceSession(VciSessionId sessionId, string code) => (SessionId, Code) = (sessionId, code);
private IssuanceSession(VciSessionState vciSessionState, string code) => (VciSessionState, Code) = (vciSessionState, code);

/// <summary>
/// Creates a new instance of <see cref="IssuanceSession"/> from the given <see cref="Uri"/>.
Expand All @@ -36,9 +36,9 @@ public static IssuanceSession FromUri(Uri uri)
throw new InvalidOperationException("Query parameter 'code' is missing");
}

var sessionIdParam = queryParams.Get("session");
var sessionId = VciSessionId.ValidSessionId(sessionIdParam).Fallback(VciSessionId.CreateSessionId());
var sessionStateParam = queryParams.Get("state");
var vciSessionState = VciSessionState.ValidVciSessionState(sessionStateParam).Fallback(VciSessionState.CreateVciSessionState());

return new IssuanceSession(sessionId, code);
return new IssuanceSession(vciSessionState, code);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ internal record PushedAuthorizationRequest

[JsonProperty("code_challenge_method")]
public string CodeChallengeMethod { get; }

[JsonProperty("state", NullValueHandling = NullValueHandling.Ignore)]
public string VciSessionState { get; }

[JsonProperty("authorization_details", NullValueHandling = NullValueHandling.Ignore)]
public AuthorizationDetails[]? AuthorizationDetails { get; }
Expand All @@ -39,7 +42,7 @@ internal record PushedAuthorizationRequest
public string? Resource { get; }

public PushedAuthorizationRequest(
VciSessionId sessionId,
VciSessionState vciSessionState,
ClientOptions clientOptions,
AuthorizationCodeParameters authorizationCodeParameters,
AuthorizationDetails[]? authorizationDetails,
Expand All @@ -49,10 +52,11 @@ public PushedAuthorizationRequest(
string? resource)
{
ClientId = clientOptions.ClientId;
RedirectUri = clientOptions.RedirectUri + "?session=" + sessionId;
RedirectUri = clientOptions.RedirectUri;
WalletIssuer = clientOptions.WalletIssuer;
CodeChallenge = authorizationCodeParameters.Challenge;
CodeChallengeMethod = authorizationCodeParameters.CodeChallengeMethod;
VciSessionState = vciSessionState;
AuthorizationDetails = authorizationDetails;
IssuerState = issuerState;
UserHint = userHint;
Expand All @@ -79,6 +83,9 @@ public FormUrlEncodedContent ToFormUrlEncoded()
if (!string.IsNullOrEmpty(CodeChallengeMethod))
keyValuePairs.Add(new KeyValuePair<string, string>("code_challenge_method", CodeChallengeMethod));

if (!string.IsNullOrEmpty(VciSessionState))
keyValuePairs.Add(new KeyValuePair<string, string>("state", VciSessionState));

if (AuthorizationDetails != null)
keyValuePairs.Add(new KeyValuePair<string, string>("authorization_details", SerializeObject(AuthorizationDetails)));

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Globalization;
using Newtonsoft.Json.Linq;
using WalletFramework.Core.Functional;
using WalletFramework.Oid4Vc.Oid4Vci.AuthFlow.Errors;

namespace WalletFramework.Oid4Vc.Oid4Vci.AuthFlow.Models;

/// <summary>
/// Identifier of the authorization state during the VCI Authorization Code Flow.
/// </summary>
public struct VciSessionState
{
/// <summary>
/// Gets the value of the state identifier.
/// </summary>
private string Value { get; }

private VciSessionState(string value) => Value = value;

/// <summary>
/// Returns the value of the state identifier.
/// </summary>
/// <param name="vciSessionState"></param>
/// <returns></returns>
public static implicit operator string(VciSessionState vciSessionState) => vciSessionState.Value;

public static Validation<VciSessionState> ValidVciSessionState(string vciSessionState)
{
if (!Guid.TryParse(vciSessionState, out _))
{
return new VciSessionStateError(vciSessionState);
}

return new VciSessionState(vciSessionState);
}

public static VciSessionState CreateVciSessionState()
{
var guid = Guid.NewGuid().ToString();
return new VciSessionState(guid);
}
}

public static class VciSessionStateFun
{
public static VciSessionState DecodeFromJson(JValue json) => VciSessionState
.ValidVciSessionState(json.ToString(CultureInfo.InvariantCulture))
.UnwrapOrThrow(new InvalidOperationException("VciSessionState is corrupt"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ namespace WalletFramework.Oid4Vc.Oid4Vci.AuthFlow.Records;
public sealed class AuthFlowSessionRecord : RecordBase
{
/// <summary>
/// The session specific id.
/// The session specific state id.
/// </summary>
[JsonIgnore]
public VciSessionId SessionId
public VciSessionState VciSessionState
{
get => VciSessionId
.ValidSessionId(Id)
get => VciSessionState
.ValidVciSessionState(Id)
.UnwrapOrThrow(new InvalidOperationException("SessionId is corrupt"));
set
{
Expand Down Expand Up @@ -56,13 +56,13 @@ public AuthFlowSessionRecord()
/// </summary>
/// <param name="authorizationData"></param>
/// <param name="authorizationCodeParameters"></param>
/// <param name="sessionId"></param>
/// <param name="vciSessionState"></param>
public AuthFlowSessionRecord(
AuthorizationData authorizationData,
AuthorizationCodeParameters authorizationCodeParameters,
VciSessionId sessionId)
VciSessionState vciSessionState)
{
SessionId = sessionId;
VciSessionState = vciSessionState;
RecordVersion = 1;
AuthorizationCodeParameters = authorizationCodeParameters;
AuthorizationData = authorizationData;
Expand Down Expand Up @@ -90,7 +90,7 @@ public static JObject EncodeToJson(this AuthFlowSessionRecord record)
public static AuthFlowSessionRecord DecodeFromJson(JObject json)
{
var idJson = json[nameof(RecordBase.Id)]!.ToObject<JValue>()!;
var id = VciSessionIdFun.DecodeFromJson(idJson);
var id = VciSessionStateFun.DecodeFromJson(idJson);

var authCodeParameters = JsonConvert.DeserializeObject<AuthorizationCodeParameters>(
json[AuthorizationCodeParametersJsonKey]!.ToString()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ public static Option<AuthorizationCode> OptionalAuthorizationCode(JToken authori
.GetByKey("authorization_server")
.OnSuccess(token => token.ToString())
.ToOption();

if (issuerState.IsNone && authServer.IsNone)
return Option<AuthorizationCode>.None;

return new AuthorizationCode(issuerState, authServer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public Oid4VciClientService(
public async Task<Uri> InitiateAuthFlow(CredentialOfferMetadata offer, ClientOptions clientOptions)
{
var authorizationCodeParameters = CreateAndStoreCodeChallenge();
var sessionId = VciSessionId.CreateSessionId();
var sessionId = VciSessionState.CreateVciSessionState();
var issuerMetadata = offer.IssuerMetadata;

var scopes = offer
Expand Down Expand Up @@ -241,7 +241,7 @@ public async Task<Validation<OneOf<SdJwtRecord, MdocRecord>>> RequestCredential(
{
var context = await _agentProvider.GetContextAsync();

var session = await _authFlowSessionStorage.GetAsync(context, issuanceSession.SessionId);
var session = await _authFlowSessionStorage.GetAsync(context, issuanceSession.VciSessionState);

var credConfiguration = session
.AuthorizationData
Expand All @@ -254,7 +254,7 @@ public async Task<Validation<OneOf<SdJwtRecord, MdocRecord>>> RequestCredential(
var tokenRequest = new TokenRequest
{
GrantType = AuthorizationCodeGrantTypeIdentifier,
RedirectUri = session.AuthorizationData.ClientOptions.RedirectUri + "?session=" + session.SessionId,
RedirectUri = session.AuthorizationData.ClientOptions.RedirectUri,
CodeVerifier = session.AuthorizationCodeParameters.Verifier,
Code = issuanceSession.Code,
ClientId = session.AuthorizationData.ClientOptions.ClientId
Expand All @@ -270,7 +270,7 @@ public async Task<Validation<OneOf<SdJwtRecord, MdocRecord>>> RequestCredential(
token,
session.AuthorizationData.ClientOptions);

await _authFlowSessionStorage.DeleteAsync(context, session.SessionId);
await _authFlowSessionStorage.DeleteAsync(context, session.VciSessionState);

var result =
from response in validResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,22 @@ public class HaipAuthorizationRequestUri
/// </summary>
public string RequestUri { get; set; } = null!;

/// <summary>
/// Validates the hap conformity of an uri and returns a HaipAuthorizationRequestUri.
/// </summary>
/// <param name="uri"></param>
/// <returns>The HaipAuthorizationRequestUri</returns>
/// <exception cref="InvalidOperationException"></exception>
public static HaipAuthorizationRequestUri FromUri(Uri uri)
{
if (!(uri.Scheme == "haip" | uri.Scheme == "openid4vp"))
throw new InvalidOperationException("Invalid Scheme. Must be haip or openid4vp");

var request = uri.GetQueryParam("request_uri");
if (string.IsNullOrEmpty(request))
throw new InvalidOperationException("HAIP requires request_uri parameter");
/// <summary>
/// Validates the hap conformity of an uri and returns a HaipAuthorizationRequestUri.
/// </summary>
/// <param name="uri"></param>
/// <returns>The HaipAuthorizationRequestUri</returns>
/// <exception cref="InvalidOperationException"></exception>
public static HaipAuthorizationRequestUri FromUri(Uri uri)
{
var request = uri.GetQueryParam("request_uri");
if (string.IsNullOrEmpty(request))
throw new InvalidOperationException("HAIP requires request_uri parameter");

return new HaipAuthorizationRequestUri()
{
RequestUri = request,
Uri = uri
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ private static SdJwtDoc _toSdJwtDoc(SdJwtRecord record)
{
return new SdJwtDoc(record.EncodedIssuerSignedJwt + "~" + string.Join("~", record.Disclosures) + "~");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void Can_Encode_To_Json()

var authorizationCodeParameters = new AuthorizationCodeParameters("hello", "world");

var sessionId = VciSessionId.CreateSessionId();
var sessionId = VciSessionState.CreateVciSessionState();
var record = new AuthFlowSessionRecord(authorizationData, authorizationCodeParameters, sessionId);

// Act
Expand All @@ -58,7 +58,7 @@ public void Can_Encode_To_Json()

// Assert
recordSut[nameof(RecordBase.Id)]!.ToString().Should().Be(record.Id);
tagsSut[nameof(AuthFlowSessionRecord.SessionId)] = record.SessionId.ToString();
tagsSut[nameof(AuthFlowSessionRecord.VciSessionState)] = record.VciSessionState.ToString();
}

[Fact]
Expand Down
Loading