|
10 | 10 | from aiohttp.client_exceptions import ClientConnectorError |
11 | 11 | from aiohttp.client_ws import ClientWebSocketResponse |
12 | 12 | from aiohttp.hdrs import AUTHORIZATION, CONTENT_TYPE |
13 | | -from aiohttp.http import WSMessage |
14 | 13 | from aiohttp.http_websocket import WSMsgType |
15 | 14 | from aiohttp.web_exceptions import HTTPBadGateway, HTTPUnauthorized |
16 | 15 |
|
| 16 | +from supervisor.utils.logging import AddonLoggerAdapter |
| 17 | + |
17 | 18 | from ..coresys import CoreSysAttributes |
18 | 19 | from ..exceptions import APIError, HomeAssistantAPIError, HomeAssistantAuthError |
19 | 20 | from ..utils.json import json_dumps |
@@ -179,23 +180,39 @@ async def _websocket_client(self) -> ClientWebSocketResponse: |
179 | 180 |
|
180 | 181 | async def _proxy_message( |
181 | 182 | self, |
182 | | - read_task: asyncio.Task, |
| 183 | + source: web.WebSocketResponse | ClientWebSocketResponse, |
183 | 184 | target: web.WebSocketResponse | ClientWebSocketResponse, |
| 185 | + logger: AddonLoggerAdapter, |
184 | 186 | ) -> None: |
185 | 187 | """Proxy a message from client to server or vice versa.""" |
186 | | - msg: WSMessage = read_task.result() |
187 | | - match msg.type: |
188 | | - case WSMsgType.TEXT: |
189 | | - await target.send_str(msg.data) |
190 | | - case WSMsgType.BINARY: |
191 | | - await target.send_bytes(msg.data) |
192 | | - case WSMsgType.CLOSE: |
193 | | - _LOGGER.debug("Received close message from WebSocket.") |
194 | | - await target.close() |
195 | | - case _: |
196 | | - raise TypeError( |
197 | | - f"Cannot proxy websocket message of unsupported type: {msg.type}" |
198 | | - ) |
| 188 | + while not source.closed and not target.closed: |
| 189 | + msg = await source.receive() |
| 190 | + match msg.type: |
| 191 | + case WSMsgType.TEXT: |
| 192 | + await target.send_str(msg.data) |
| 193 | + case WSMsgType.BINARY: |
| 194 | + await target.send_bytes(msg.data) |
| 195 | + case WSMsgType.CLOSE | WSMsgType.CLOSED: |
| 196 | + logger.debug( |
| 197 | + "Received WebSocket message type %r from %s.", |
| 198 | + msg.type, |
| 199 | + "add-on" if type(source) is web.WebSocketResponse else "Core", |
| 200 | + ) |
| 201 | + await target.close() |
| 202 | + case WSMsgType.CLOSING: |
| 203 | + pass |
| 204 | + case WSMsgType.ERROR: |
| 205 | + logger.warning( |
| 206 | + "Error WebSocket message received while proxying: %r", msg.data |
| 207 | + ) |
| 208 | + await target.close(code=source.close_code) |
| 209 | + case _: |
| 210 | + logger.warning( |
| 211 | + "Cannot proxy WebSocket message of unsupported type: %r", |
| 212 | + msg.type, |
| 213 | + ) |
| 214 | + await source.close() |
| 215 | + await target.close() |
199 | 216 |
|
200 | 217 | async def websocket(self, request: web.Request): |
201 | 218 | """Initialize a WebSocket API connection.""" |
@@ -255,48 +272,32 @@ async def websocket(self, request: web.Request): |
255 | 272 | except APIError: |
256 | 273 | return server |
257 | 274 |
|
258 | | - _LOGGER.info("Home Assistant WebSocket API request running") |
259 | | - try: |
260 | | - client_read: asyncio.Task | None = None |
261 | | - server_read: asyncio.Task | None = None |
262 | | - while not server.closed and not client.closed: |
263 | | - if not client_read: |
264 | | - client_read = self.sys_create_task(client.receive()) |
265 | | - if not server_read: |
266 | | - server_read = self.sys_create_task(server.receive()) |
267 | | - |
268 | | - # wait until data need to be processed |
269 | | - await asyncio.wait( |
270 | | - [client_read, server_read], return_when=asyncio.FIRST_COMPLETED |
271 | | - ) |
| 275 | + logger = AddonLoggerAdapter(_LOGGER, {"addon_name": addon_name}) |
| 276 | + logger.info("Home Assistant WebSocket API proxy running") |
272 | 277 |
|
273 | | - # server |
274 | | - if server_read.done() and not client.closed: |
275 | | - await self._proxy_message(server_read, client) |
276 | | - server_read = None |
| 278 | + client_task = self.sys_create_task(self._proxy_message(client, server, logger)) |
| 279 | + server_task = self.sys_create_task(self._proxy_message(server, client, logger)) |
277 | 280 |
|
278 | | - # client |
279 | | - if client_read.done() and not server.closed: |
280 | | - await self._proxy_message(client_read, server) |
281 | | - client_read = None |
| 281 | + # Typically, this will return with an empty pending set. However, if one of |
| 282 | + # the directions has an exception, make sure to close both connections and |
| 283 | + # wait for the other proxy task to exit gracefully. Using this over try-except |
| 284 | + # handling makes it easier to wait for the other direction to complete. |
| 285 | + _, pending = await asyncio.wait( |
| 286 | + (client_task, server_task), return_when=asyncio.FIRST_EXCEPTION |
| 287 | + ) |
282 | 288 |
|
283 | | - except asyncio.CancelledError: |
284 | | - pass |
| 289 | + if not client.closed: |
| 290 | + await client.close() |
| 291 | + if not server.closed: |
| 292 | + await server.close() |
285 | 293 |
|
286 | | - except (RuntimeError, ConnectionError, TypeError) as err: |
287 | | - _LOGGER.info("Home Assistant WebSocket API error: %s", err) |
288 | | - |
289 | | - finally: |
290 | | - if client_read and not client_read.done(): |
291 | | - client_read.cancel() |
292 | | - if server_read and not server_read.done(): |
293 | | - server_read.cancel() |
294 | | - |
295 | | - # close connections |
296 | | - if not client.closed: |
297 | | - await client.close() |
298 | | - if not server.closed: |
299 | | - await server.close() |
| 294 | + if pending: |
| 295 | + _, pending = await asyncio.wait( |
| 296 | + pending, timeout=10, return_when=asyncio.ALL_COMPLETED |
| 297 | + ) |
| 298 | + for task in pending: |
| 299 | + task.cancel() |
| 300 | + logger.critical("WebSocket proxy task: %s did not end gracefully", task) |
300 | 301 |
|
301 | | - _LOGGER.info("Home Assistant WebSocket API for %s closed", addon_name) |
| 302 | + logger.info("Home Assistant WebSocket API closed") |
302 | 303 | return server |
0 commit comments