Skip to content

Commit

Permalink
Preserve access tokens when refreshing tunnel (#288)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasongin authored Aug 3, 2023
1 parent ba98b39 commit d54b31b
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 81 deletions.
17 changes: 0 additions & 17 deletions cs/src/Connections/TunnelConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,6 @@ private set
{
if (value != this.tunnel)
{
// Get the tunnel access token from the new tunnel, or the original Tunnal object if the new tunnel doesn't have the token,
// which may happen when the tunnel was authenticated with a tunnel access token from Tunnel.AccessTokens.
// Add the tunnel access token to the new tunnel's AccessTokens if it is not there.

// TODO: remove this access token preservation logic when #990 is fixed.
string? accessToken;
if (value != null &&
!value.TryGetAccessToken(TunnelAccessScope, out var _) &&
this.tunnel?.TryGetAccessToken(TunnelAccessScope, out accessToken) == true &&
!string.IsNullOrEmpty(accessToken) &&
TunnelAccessTokenProperties.TryParse(accessToken) is TunnelAccessTokenProperties tokenProperties &&
(tokenProperties.Expiration == null || tokenProperties.Expiration > DateTime.UtcNow))
{
value.AccessTokens ??= new Dictionary<string, string>();
value.AccessTokens[TunnelAccessScope] = accessToken;
}

this.tunnel = value;
OnTunnelChanged();
}
Expand Down
63 changes: 48 additions & 15 deletions cs/src/Management/TunnelManagementClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@ public async Task<Tunnel[]> SearchTunnelsAsync(
query: GetApiQuery(),
options,
cancellation);
PreserveAccessTokens(tunnel, result);
return result;
}

Expand All @@ -920,6 +921,7 @@ public async Task<Tunnel> CreateTunnelAsync(
options,
ConvertTunnelForRequest(tunnel),
cancellation);
PreserveAccessTokens(tunnel, result);
return result!;
}

Expand All @@ -938,14 +940,7 @@ public async Task<Tunnel> UpdateTunnelAsync(
options,
ConvertTunnelForRequest(tunnel),
cancellation);

// If no new tokens were requested in the update, preserve any existing
// access tokens in the resulting tunnel object.
if (options?.TokenScopes == null)
{
result!.AccessTokens = tunnel.AccessTokens;
}

PreserveAccessTokens(tunnel, result);
return result!;
}

Expand Down Expand Up @@ -1086,6 +1081,7 @@ public async Task<TunnelPort> CreateTunnelPortAsync(
options,
ConvertTunnelPortForRequest(tunnel, tunnelPort),
cancellation))!;
PreserveAccessTokens(tunnelPort, result);

if (tunnel.Ports != null)
{
Expand Down Expand Up @@ -1127,6 +1123,7 @@ public async Task<TunnelPort> UpdateTunnelPortAsync(
options,
ConvertTunnelPortForRequest(tunnel, tunnelPort),
cancellation))!;
PreserveAccessTokens(tunnelPort, result);

if (tunnel.Ports != null)
{
Expand All @@ -1138,13 +1135,6 @@ public async Task<TunnelPort> UpdateTunnelPortAsync(
.ToArray();
}

// If no new tokens were requested in the update, preserve any existing
// access tokens in the resulting port object.
if (options?.TokenScopes == null)
{
result!.AccessTokens = tunnelPort.AccessTokens;
}

return result;
}

Expand Down Expand Up @@ -1335,5 +1325,48 @@ public async Task<bool> CheckNameAvailabilityAsync(
{
return string.IsNullOrEmpty(ApiVersion) ? null : $"api-version={ApiVersion}";
}

/// <summary>
/// Copy access tokens from the request object to the result object, except for any
/// tokens that were refreshed by the request.
/// </summary>
/// <remarks>
/// This intentionally does not check whether any existing tokens are expired. So
/// expired tokens may be preserved also, if not refreshed. This allows for better
/// diagnostics in that case.
/// </remarks>
private static void PreserveAccessTokens(Tunnel requestTunnel, Tunnel? resultTunnel)
{
if (requestTunnel.AccessTokens != null && resultTunnel != null)
{
resultTunnel.AccessTokens ??= new Dictionary<string, string>();
foreach (var scopeAndToken in requestTunnel.AccessTokens)
{
if (!resultTunnel.AccessTokens.ContainsKey(scopeAndToken.Key))
{
resultTunnel.AccessTokens[scopeAndToken.Key] = scopeAndToken.Value;
}
}
}
}

/// <summary>
/// Copy access tokens from the request object to the result object, except for any
/// tokens that were refreshed by the request.
/// </summary>
private static void PreserveAccessTokens(TunnelPort requestPort, TunnelPort? resultPort)
{
if (requestPort.AccessTokens != null && resultPort != null)
{
resultPort.AccessTokens ??= new Dictionary<string, string>();
foreach (var scopeAndToken in requestPort.AccessTokens)
{
if (!resultPort.AccessTokens.ContainsKey(scopeAndToken.Key))
{
resultPort.AccessTokens[scopeAndToken.Key] = scopeAndToken.Value;
}
}
}
}
}
}
78 changes: 63 additions & 15 deletions cs/test/TunnelsSDK.Test/TunnelManagementClientTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Net;
using System.Net;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using Microsoft.DevTunnels.Contracts;
Expand All @@ -14,14 +14,9 @@ public class TunnelManagementClientTests
private readonly CancellationToken timeout = System.Diagnostics.Debugger.IsAttached ? default : new CancellationTokenSource(TimeSpan.FromSeconds(5)).Token;
private readonly ProductInfoHeaderValue userAgent = TunnelUserAgent.GetUserAgent(typeof(TunnelManagementClientTests).Assembly);
private readonly Uri tunnelServiceUri = new Uri("https://localhost:3000/");
private readonly Tunnel tunnel = new Tunnel
{
TunnelId = TunnelId,
ClusterId = ClusterId,
};

[Fact]
public async Task TunnelRequestOptions_SetRequestOption()
public async Task HttpRequestOptions()
{
var options = new TunnelRequestOptions()
{
Expand All @@ -32,33 +27,86 @@ public async Task TunnelRequestOptions_SetRequestOption()
}
};

var tunnel = new Tunnel
{
TunnelId = TunnelId,
ClusterId = ClusterId,
};

var handler = new MockHttpMessageHandler(
(message, ct) =>
(message, ct) =>
{
Assert.True(message.Options.TryGetValue(new HttpRequestOptionsKey<string>("foo"), out string strValue) && strValue == "bar");
Assert.True(message.Options.TryGetValue(new HttpRequestOptionsKey<int>("bazz"), out int intValue) && intValue == 100);
return GetTunnelResponseAsync();
var result = new HttpResponseMessage(HttpStatusCode.OK);
result.Content = JsonContent.Create(tunnel);
return Task.FromResult(result);
});

var client = new TunnelManagementClient(this.userAgent, null, this.tunnelServiceUri, handler);
var tunnel = await client.GetTunnelAsync(this.tunnel, options, this.timeout);

tunnel = await client.GetTunnelAsync(tunnel, options, this.timeout);
Assert.NotNull(tunnel);
Assert.Equal(TunnelId, tunnel.TunnelId);
Assert.Equal(ClusterId, tunnel.ClusterId);
}

private Task<HttpResponseMessage> GetTunnelResponseAsync()
[Fact]
public async Task PreserveAccessTokens()
{
var result = new HttpResponseMessage(HttpStatusCode.OK);
result.Content = JsonContent.Create(this.tunnel);
return Task.FromResult(result);
var requestTunnel = new Tunnel
{
TunnelId = TunnelId,
ClusterId = ClusterId,
AccessTokens = new Dictionary<string, string>
{
[TunnelAccessScopes.Manage] = "manage-token-1",
[TunnelAccessScopes.Connect] = "connect-token-1",
},
};

var handler = new MockHttpMessageHandler(
(message, ct) =>
{
var responseTunnel = new Tunnel
{
TunnelId = TunnelId,
ClusterId = ClusterId,
AccessTokens = new Dictionary<string, string>
{
[TunnelAccessScopes.Manage] = "manage-token-2",
[TunnelAccessScopes.Host] = "host-token-2",
},
};
var result = new HttpResponseMessage(HttpStatusCode.OK);
result.Content = JsonContent.Create(responseTunnel);
return Task.FromResult(result);
});
var client = new TunnelManagementClient(this.userAgent, null, this.tunnelServiceUri, handler);

var resultTunnel = await client.GetTunnelAsync(requestTunnel, options: null, this.timeout);
Assert.NotNull(resultTunnel);
Assert.NotNull(resultTunnel.AccessTokens);

// Tokens in the request tunnel should be preserved, unless updated by the response.
Assert.Collection(
resultTunnel.AccessTokens.OrderBy((item) => item.Key),
(item) => Assert.Equal(new KeyValuePair<string, string>(
TunnelAccessScopes.Connect, "connect-token-1"), item), // preserved
(item) => Assert.Equal(new KeyValuePair<string, string>(
TunnelAccessScopes.Host, "host-token-2"), item), // added
(item) => Assert.Equal(new KeyValuePair<string, string>(
TunnelAccessScopes.Manage, "manage-token-2"), item)); // updated

}

private sealed class MockHttpMessageHandler : DelegatingHandler
{
private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler;

public MockHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
public MockHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
: base(new HttpClientHandler
{
AllowAutoRedirect = false,
Expand Down
20 changes: 0 additions & 20 deletions ts/src/connections/tunnelConnectionSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,6 @@ export class TunnelConnectionSession extends TunnelConnectionBase implements Tun

private set tunnel(value: Tunnel | null) {
if (value !== this.connectedTunnel) {

// Get the tunnel access token from the new tunnel, or the original Tunnal object if the new tunnel doesn't have the token,
// which may happen when the tunnel was authenticated with a tunnel access token from Tunnel.AccessTokens.
// Add the tunnel access token to the new tunnel's AccessTokens if it is not there.

// TODO: remove this access token preservation logic when #990 is fixed.
if (value &&
!TunnelAccessTokenProperties.getTunnelAccessToken(value, this.tunnelAccessScope)) {

const accessToken = TunnelAccessTokenProperties.getTunnelAccessToken(
this.tunnel,
this.tunnelAccessScope,
);

if (accessToken) {
value.accessTokens ??= {};
value.accessTokens[this.tunnelAccessScope] = accessToken;
}
}

this.connectedTunnel = value;
this.tunnelChanged();
}
Expand Down
41 changes: 27 additions & 14 deletions ts/src/management/tunnelManagementHttpClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ function parseTunnelPortDates(port: TunnelPort | null) {
}
}

/**
* Copy access tokens from the request object to the result object, except for any
* tokens that were refreshed by the request.
*/
function preserveAccessTokens<T extends Tunnel | TunnelPort>(
requestObject: T,
resultObject: T | null,
) {
// This intentionally does not check whether any existing tokens are expired. So
// expired tokens may be preserved also, if not refreshed. This allows for better
// diagnostics in that case.
if (requestObject.accessTokens && resultObject) {
resultObject.accessTokens ??= {};
for (const scopeAndToken of Object.entries(requestObject.accessTokens)) {
if (!resultObject.accessTokens[scopeAndToken[0]]) {
resultObject.accessTokens[scopeAndToken[0]] = scopeAndToken[1];
}
}
}
}


const manageAccessTokenScope = [TunnelAccessScopes.Manage];
const hostAccessTokenScope = [TunnelAccessScopes.Host];
const managePortsAccessTokenScopes = [
Expand Down Expand Up @@ -200,6 +222,7 @@ export class TunnelManagementHttpClient implements TunnelManagementClient {
undefined,
options,
);
preserveAccessTokens(tunnel, result);
parseTunnelDates(result);
return result;
}
Expand All @@ -219,6 +242,7 @@ export class TunnelManagementHttpClient implements TunnelManagementClient {
options,
tunnel,
))!;
preserveAccessTokens(tunnel, result);
parseTunnelDates(result);
return result;
}
Expand All @@ -233,13 +257,7 @@ export class TunnelManagementHttpClient implements TunnelManagementClient {
options,
this.convertTunnelForRequest(tunnel),
))!;

if (!options?.tokenScopes) {
// If no new tokens were requested in the update, preserve any existing
// access tokens in the resulting tunnel object.
result.accessTokens = tunnel.accessTokens;
}

preserveAccessTokens(tunnel, result);
parseTunnelDates(result);
return result;
}
Expand Down Expand Up @@ -412,6 +430,8 @@ export class TunnelManagementHttpClient implements TunnelManagementClient {
options,
tunnelPort,
))!;
preserveAccessTokens(tunnelPort, result);
parseTunnelPortDates(result);

if (tunnel.ports) {
// Also update the port in the local tunnel object.
Expand All @@ -421,13 +441,6 @@ export class TunnelManagementHttpClient implements TunnelManagementClient {
.sort(comparePorts);
}

if (!options?.tokenScopes) {
// If no new tokens were requested in the update, preserve any existing
// access tokens in the resulting port object.
result.accessTokens = tunnelPort.accessTokens;
}

parseTunnelPortDates(result);
return result;
}

Expand Down
Loading

0 comments on commit d54b31b

Please sign in to comment.