Skip to content

Commit

Permalink
Mark all tracked users as dirty on expired SSS connections
Browse files Browse the repository at this point in the history
See matrix-org/matrix-rust-sdk#3965 for
more information. Requires `Extension.onRequest` to be `async`.
  • Loading branch information
kegsay committed Sep 18, 2024
1 parent 1fd6675 commit 8616294
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 54 deletions.
42 changes: 27 additions & 15 deletions spec/integ/sliding-sync-sdk.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -640,11 +640,13 @@ describe("SlidingSyncSdk", () => {
client!.crypto!.stop();
});

it("gets enabled on the initial request only", () => {
expect(ext.onRequest(true)).toEqual({
it("gets enabled all the time", async () => {
expect(await ext.onRequest(true)).toEqual({
enabled: true,
});
expect(await ext.onRequest(false)).toEqual({
enabled: true,
});
expect(ext.onRequest(false)).toEqual(undefined);
});

it("can update device lists", () => {
Expand Down Expand Up @@ -686,11 +688,13 @@ describe("SlidingSyncSdk", () => {
ext = findExtension("account_data");
});

it("gets enabled on the initial request only", () => {
expect(ext.onRequest(true)).toEqual({
it("gets enabled all the time", async () => {
expect(await ext.onRequest(true)).toEqual({
enabled: true,
});
expect(await ext.onRequest(false)).toEqual({
enabled: true,
});
expect(ext.onRequest(false)).toEqual(undefined);
});

it("processes global account data", async () => {
Expand Down Expand Up @@ -814,8 +818,12 @@ describe("SlidingSyncSdk", () => {
ext = findExtension("to_device");
});

it("gets enabled with a limit on the initial request only", () => {
const reqJson: any = ext.onRequest(true);
it("gets enabled all the time", async () => {
let reqJson: any = await ext.onRequest(true);
expect(reqJson.enabled).toEqual(true);
expect(reqJson.limit).toBeGreaterThan(0);
expect(reqJson.since).toBeUndefined();
reqJson = await ext.onRequest(false);
expect(reqJson.enabled).toEqual(true);
expect(reqJson.limit).toBeGreaterThan(0);
expect(reqJson.since).toBeUndefined();
Expand All @@ -826,7 +834,7 @@ describe("SlidingSyncSdk", () => {
next_batch: "12345",
events: [],
});
expect(ext.onRequest(false)).toEqual({
expect(await ext.onRequest(false)).toMatchObject({
since: "12345",
});
});
Expand Down Expand Up @@ -910,11 +918,13 @@ describe("SlidingSyncSdk", () => {
ext = findExtension("typing");
});

it("gets enabled on the initial request only", () => {
expect(ext.onRequest(true)).toEqual({
it("gets enabled all the time", async () => {
expect(await ext.onRequest(true)).toEqual({
enabled: true,
});
expect(await ext.onRequest(false)).toEqual({
enabled: true,
});
expect(ext.onRequest(false)).toEqual(undefined);
});

it("processes typing notifications", async () => {
Expand Down Expand Up @@ -1035,11 +1045,13 @@ describe("SlidingSyncSdk", () => {
ext = findExtension("receipts");
});

it("gets enabled on the initial request only", () => {
expect(ext.onRequest(true)).toEqual({
it("gets enabled all the time", async () => {
expect(await ext.onRequest(true)).toEqual({
enabled: true,
});
expect(await ext.onRequest(false)).toEqual({
enabled: true,
});
expect(ext.onRequest(false)).toEqual(undefined);
});

it("processes receipts", async () => {
Expand Down
14 changes: 7 additions & 7 deletions spec/integ/sliding-sync.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ describe("SlidingSync", () => {
};
const ext: Extension<any, any> = {
name: () => "custom_extension",
onRequest: (initial) => {
return { initial: initial };
onRequest: async (_) => {
return { initial: true };
},
onResponse: async (res) => {
return;
Expand Down Expand Up @@ -827,7 +827,7 @@ describe("SlidingSync", () => {

const extPre: Extension<any, any> = {
name: () => preExtName,
onRequest: (initial) => {
onRequest: async (initial) => {
return onPreExtensionRequest(initial);
},
onResponse: (res) => {
Expand All @@ -837,7 +837,7 @@ describe("SlidingSync", () => {
};
const extPost: Extension<any, any> = {
name: () => postExtName,
onRequest: (initial) => {
onRequest: async (initial) => {
return onPostExtensionRequest(initial);
},
onResponse: (res) => {
Expand All @@ -852,7 +852,7 @@ describe("SlidingSync", () => {

const callbackOrder: string[] = [];
let extensionOnResponseCalled = false;
onPreExtensionRequest = () => {
onPreExtensionRequest = async () => {
return extReq;
};
onPreExtensionResponse = async (resp) => {
Expand Down Expand Up @@ -892,7 +892,7 @@ describe("SlidingSync", () => {
});

it("should be able to send nothing in an extension request/response", async () => {
onPreExtensionRequest = () => {
onPreExtensionRequest = async () => {
return undefined;
};
let responseCalled = false;
Expand Down Expand Up @@ -927,7 +927,7 @@ describe("SlidingSync", () => {

it("is possible to register extensions after start() has been called", async () => {
slidingSync.registerExtension(extPost);
onPostExtensionRequest = () => {
onPostExtensionRequest = async () => {
return extReq;
};
let responseCalled = false;
Expand Down
9 changes: 9 additions & 0 deletions src/common-crypto/CryptoBackend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ export interface SyncCryptoCallbacks {
* @param syncState - information about the completed sync.
*/
onSyncCompleted(syncState: OnSyncCompletedData): void;

/**
* Mark all tracked user's device lists as dirty.
*
* This method will cause additional /keys/query requests on the server, so should be used only
* when the client has desynced tracking device list deltas from the server.
* In MSC4186: Simplified Sliding Sync, this can happen when the server expires the connection.
*/
markAllTrackedUsersAsDirty(): Promise<void>;
}

/**
Expand Down
7 changes: 7 additions & 0 deletions src/crypto/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3445,6 +3445,13 @@ export class Crypto extends TypedEventEmitter<CryptoEvent, CryptoEventHandlerMap
}
}

/**
* Implementation of {@link CryptoApi#markAllTrackedUsersAsDirty}.
*/
public async markAllTrackedUsersAsDirty(): Promise<void> {
// no op: we only expect rust crypto to be used in MSC4186.
}

/**
* Trigger the appropriate invalidations and removes for a given
* device list
Expand Down
45 changes: 20 additions & 25 deletions src/sliding-sync-sdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,16 @@ class ExtensionE2EE implements Extension<ExtensionE2EERequest, ExtensionE2EEResp
return ExtensionState.PreProcess;
}

public onRequest(isInitial: boolean): ExtensionE2EERequest | undefined {
if (!isInitial) {
return undefined;
public async onRequest(isInitial: boolean): Promise<ExtensionE2EERequest> {
if (isInitial) {
// In SSS, the `?pos=` contains the stream position for device list updates.
// If we do not have a `?pos=` (e.g because we forgot it, or because the server
// invalidated our connection) then we MUST invlaidate all device lists because
// the server will not tell us the delta. This will then cause UTDs as we will fail
// to encrypt for new devices. This is an expensive call, so we should
// really really remember `?pos=` wherever possible.
logger.log("ExtensionE2EE: invalidating all device lists due to missing 'pos'");
await this.crypto.markAllTrackedUsersAsDirty();
}
return {
enabled: true, // this is sticky so only send it on the initial request
Expand Down Expand Up @@ -127,15 +134,12 @@ class ExtensionToDevice implements Extension<ExtensionToDeviceRequest, Extension
return ExtensionState.PreProcess;
}

public onRequest(isInitial: boolean): ExtensionToDeviceRequest {
const extReq: ExtensionToDeviceRequest = {
public async onRequest(isInitial: boolean): Promise<ExtensionToDeviceRequest> {
return {
since: this.nextBatch !== null ? this.nextBatch : undefined,
limit: 100,
enabled: true,
};
if (isInitial) {
extReq["limit"] = 100;
extReq["enabled"] = true;
}
return extReq;
}

public async onResponse(data: ExtensionToDeviceResponse): Promise<void> {
Expand Down Expand Up @@ -209,10 +213,7 @@ class ExtensionAccountData implements Extension<ExtensionAccountDataRequest, Ext
return ExtensionState.PostProcess;
}

public onRequest(isInitial: boolean): ExtensionAccountDataRequest | undefined {
if (!isInitial) {
return undefined;
}
public async onRequest(isInitial: boolean): Promise<ExtensionAccountDataRequest> {
return {
enabled: true,
};
Expand Down Expand Up @@ -279,10 +280,7 @@ class ExtensionTyping implements Extension<ExtensionTypingRequest, ExtensionTypi
return ExtensionState.PostProcess;
}

public onRequest(isInitial: boolean): ExtensionTypingRequest | undefined {
if (!isInitial) {
return undefined; // don't send a JSON object for subsequent requests, we don't need to.
}
public async onRequest(isInitial: boolean): Promise<ExtensionTypingRequest> {
return {
enabled: true,
};
Expand Down Expand Up @@ -318,13 +316,10 @@ class ExtensionReceipts implements Extension<ExtensionReceiptsRequest, Extension
return ExtensionState.PostProcess;
}

public onRequest(isInitial: boolean): ExtensionReceiptsRequest | undefined {
if (isInitial) {
return {
enabled: true,
};
}
return undefined; // don't send a JSON object for subsequent requests, we don't need to.
public async onRequest(isInitial: boolean): Promise<ExtensionReceiptsRequest> {
return {
enabled: true,
};
}

public async onResponse(data: ExtensionReceiptsResponse): Promise<void> {
Expand Down
14 changes: 7 additions & 7 deletions src/sliding-sync.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,10 @@ export interface Extension<Req extends {}, Res extends {}> {
/**
* A function which is called when the request JSON is being formed.
* Returns the data to insert under this key.
* @param isInitial - True when this is part of the initial request (send sticky params)
* @param isInitial - True when this is part of the initial request.
* @returns The request JSON to send.
*/
onRequest(isInitial: boolean): Req | undefined;
onRequest(isInitial: boolean): Promise<Req>;
/**
* A function which is called when there is response JSON under this extension.
* @param data - The response JSON under the extension name.
Expand Down Expand Up @@ -471,11 +471,11 @@ export class SlidingSync extends TypedEventEmitter<SlidingSyncEvent, SlidingSync
this.extensions[ext.name()] = ext;
}

private getExtensionRequest(): Record<string, object | undefined> {
private async getExtensionRequest(isInitial: boolean): Promise<Record<string, object | undefined>> {
const ext: Record<string, object | undefined> = {};
Object.keys(this.extensions).forEach((extName) => {
ext[extName] = this.extensions[extName].onRequest(true);
});
for (const extName in this.extensions) {
ext[extName] = await this.extensions[extName].onRequest(isInitial);
}
return ext;
}

Expand Down Expand Up @@ -582,7 +582,7 @@ export class SlidingSync extends TypedEventEmitter<SlidingSyncEvent, SlidingSync
pos: currentPos,
timeout: this.timeoutMS,
clientTimeout: this.timeoutMS + BUFFER_PERIOD_MS,
extensions: this.getExtensionRequest(),
extensions: await this.getExtensionRequest(currentPos === undefined),
};
// check if we are (un)subscribing to a room and modify request this one time for it
const newSubscriptions = difference(this.desiredRoomSubscriptions, this.confirmedRoomSubscriptions);
Expand Down

0 comments on commit 8616294

Please sign in to comment.