From 2fe286b99be5c1f4b922237191fd39da31f09c22 Mon Sep 17 00:00:00 2001 From: Igor Artemenko Date: Wed, 28 Dec 2022 19:52:52 +0000 Subject: [PATCH] Enable chunked downloads Before this change, whenever the user would download an attachment, the client (Pantalaimon) would store all the data in memory while processing it. This causes significant memory usage spikes when downloading multiple large files at once. The high memory usage is a major problem when running Pantalaimon in a memory-constrained setting. This change uses new matrix-nio functions to download and decrypt the attachments in chunks, which resolves the memory usage issues. --- pantalaimon/daemon.py | 49 ++++++++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 3264cab..e3486ae 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -36,10 +36,10 @@ LoginResponse, OlmTrustError, SendRetryError, - DownloadResponse, + StreamResponse, UploadResponse, ) -from nio.crypto import decrypt_attachment +from nio.crypto import async_decrypt_attachment from pantalaimon.client import ( SEARCH_TERMS_SCHEMA, @@ -549,7 +549,7 @@ async def forward_to_web( status=response.status, content_type=response.content_type, headers=CORS_HEADERS, - body=await response.read(), + body=response.content, ) except ClientConnectionError as e: return web.Response(status=500, text=str(e)) @@ -871,16 +871,18 @@ async def _decrypt_uri(self, content_uri, client): if not upload_info or not media_info: raise NotDecryptedAvailableError - response, decrypted_file = await self._load_decrypted_file( + response, decrypted_file_generator = await self._load_decrypted_file( media_info.mxc_server, media_info.mxc_path, upload_info.filename ) - if response is None and decrypted_file is None: + if response is None and decrypted_file_generator is None: raise NotDecryptedAvailableError - if not isinstance(response, DownloadResponse): + if not isinstance(response, StreamResponse): raise NotDecryptedAvailableError + decrypted_file = b"".join([chunk async for chunk in decrypted_file_generator]) + decrypted_upload, _ = await client.upload( data_provider=BufferedReader(BytesIO(decrypted_file)), content_type=upload_info.mimetype, @@ -1271,24 +1273,23 @@ async def _load_decrypted_file(self, server_name, media_id, file_name): return None, None client = next(iter(self.pan_clients.values())) + mxc = f"mxc://{server_name}/{media_id}" try: - response = await client.download(server_name, media_id, file_name) + response = await client.stream(mxc=mxc, filename=file_name) except ClientConnectionError as e: raise e - if not isinstance(response, DownloadResponse): + if not isinstance(response, StreamResponse): return response, None logger.info(f"Decrypting media {server_name}/{media_id}") - loop = asyncio.get_running_loop() - with concurrent.futures.ProcessPoolExecutor() as pool: - decrypted_file = await loop.run_in_executor( - pool, decrypt_attachment, response.body, key, hash, media_info.iv - ) + decrypted_file_generator = async_decrypt_attachment( + response.generator, key, hash, media_info.iv + ) - return response, decrypted_file + return response, decrypted_file_generator async def profile(self, request): access_token = self.get_access_token(request) @@ -1323,18 +1324,18 @@ async def download(self, request): file_name = request.match_info.get("file_name") try: - response, decrypted_file = await self._load_decrypted_file( + response, decrypted_file_generator = await self._load_decrypted_file( server_name, media_id, file_name ) - if response is None and decrypted_file is None: + if response is None and decrypted_file_generator is None: return await self.forward_to_web(request) except ClientConnectionError as e: return web.Response(status=500, text=str(e)) except KeyError: return await self.forward_to_web(request) - if not isinstance(response, DownloadResponse): + if not isinstance(response, StreamResponse): return web.Response( status=response.transport_response.status, content_type=response.transport_response.content_type, @@ -1342,12 +1343,16 @@ async def download(self, request): body=await response.transport_response.read(), ) - return web.Response( - status=response.transport_response.status, - content_type=response.transport_response.content_type, - headers=CORS_HEADERS, - body=decrypted_file, + stream_response = web.StreamResponse( + status=response.transport_response.status, headers=CORS_HEADERS ) + stream_response.content_length = response.transport_response.content_length + stream_response.content_type = response.content_type + await stream_response.prepare(request) + async for chunk in decrypted_file_generator: + await stream_response.write(chunk) + await stream_response.write_eof() + return stream_response async def well_known(self, _): """Intercept well-known requests