Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Bulk claim OTKs
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Oct 27, 2023
1 parent d7968df commit 62fc3bd
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -1131,25 +1132,31 @@ async def claim_e2e_one_time_keys(
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
# allows us to use autocommit mode.
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
for user_id, device_id, algorithm, count in query_list:
claim_rows = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_key_returning,
user_id,
device_id,
algorithm,
count,
db_autocommit=True,
unfulfilled_claim_counts[user_id, device_id, algorithm] = count

bulk_claims = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
self._claim_e2e_one_time_key_returning,
query_list,
db_autocommit=True,
)

for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
if claim_rows:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
for claim_row in claim_rows:
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
# Did we get enough OTKs?
count -= len(claim_rows)
if count:
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1

# Did we get enough OTKs?
for (
user_id,
device_id,
algorithm,
), count in unfulfilled_claim_counts.items():
if count > 0:
missing.append((user_id, device_id, algorithm, count))
else:
for user_id, device_id, algorithm, count in query_list:
Expand Down Expand Up @@ -1277,43 +1284,49 @@ def _claim_e2e_one_time_key_simple(
def _claim_e2e_one_time_key_returning(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
query_list: Collection[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
Args:
query_list: Collection of tuples (user_id, device_id, algorithm, count)
as passed to claim_e2e_one_time_keys.
Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""

# We can use RETURNING to do the fetch and DELETE in once step.
sql = """
DELETE FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
)
RETURNING key_id, key_json
"""

txn.execute(
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
)
otk_rows = list(txn)
if not otk_rows:
return []
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES (?)
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
DELETE FROM e2e_one_time_keys_json k
WHERE (user_id, device_id, algorithm, key_id) IN (
SELECT user_id, device_id, algorithm, key_id
FROM ranked_keys
WHERE r <= claim_count
)
RETURNING user_id, device_id, algorithm, key_id, key_json;
"""
txn.execute_values(sql, query_list)
otk_rows = cast(List[Tuple[str, str, str, str, str]], list(txn))

self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
return otk_rows


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
Expand Down

0 comments on commit 62fc3bd

Please sign in to comment.