Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Premier essai (NE PAS MERGER!!!) #224

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
4e616c4
Premier essai
ic-dev21 Nov 22, 2024
08e40ce
Linting
ic-dev21 Nov 22, 2024
be41a7b
quelques progrès
ic-dev21 Nov 25, 2024
9246a21
Linting, imports oubliés
ic-dev21 Nov 25, 2024
97a8ca7
Keeping the right tokens for every websocket connection
Leicas Nov 25, 2024
e0026a1
Merge pull request #225 from Leicas/WebsocketManagerClass
ic-dev21 Nov 25, 2024
9dbeadf
Too many arguments
ic-dev21 Nov 25, 2024
561bd1b
ça fonctionne comme ça
Leicas Nov 26, 2024
74d43fa
Merge pull request #226 from Leicas/WebsocketManagerClass
ic-dev21 Nov 26, 2024
dfc49f1
Some Linting
ic-dev21 Nov 26, 2024
465e2ff
Some more linting
ic-dev21 Nov 26, 2024
8bfd6c0
better logs to see which ws is doing what.
Leicas Nov 26, 2024
5ef53c4
Merge pull request #227 from Leicas/WebsocketManagerClass
ic-dev21 Nov 26, 2024
bd6cd75
Merge branch 'WebsocketManagerClass' of https://github.com/dvd-dev/py…
ic-dev21 Nov 26, 2024
fb35d87
Revert "Some more linting"
ic-dev21 Nov 27, 2024
70cc6bb
More linting
ic-dev21 Nov 27, 2024
5659b41
Fix hilo_state.yaml bug
ic-dev21 Nov 28, 2024
0338018
Remove uneeded function
ic-dev21 Nov 28, 2024
15923f5
Remove unused import
ic-dev21 Dec 1, 2024
4e6ab3d
Linting
ic-dev21 Dec 1, 2024
cc86936
Linting
ic-dev21 Dec 1, 2024
83c3ad6
Some more linting
ic-dev21 Dec 2, 2024
414b256
Update api.py
ic-dev21 Dec 6, 2024
0c9fca9
Update api.py
ic-dev21 Dec 6, 2024
69fc037
Update api.py
ic-dev21 Jan 3, 2025
c59d424
Meilleure identification
ic-dev21 Jan 3, 2025
2e99e95
Logging + websocket en parallèle
ic-dev21 Jan 17, 2025
76781a7
Removing logging
ic-dev21 Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions pyhilo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
API_NOTIFICATIONS_ENDPOINT,
API_REGISTRATION_ENDPOINT,
API_REGISTRATION_HEADERS,
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_STATE_FILE,
DEFAULT_USER_AGENT,
Expand All @@ -51,7 +52,7 @@
get_state,
set_state,
)
from pyhilo.websocket import WebsocketClient
from pyhilo.websocket import WebsocketClient, WebsocketManager


class API:
Expand Down Expand Up @@ -216,17 +217,24 @@ async def _async_request(
:rtype: dict[str, Any]
"""
kwargs.setdefault("headers", self.headers)
access_token = await self.async_get_access_token()

if endpoint.startswith(API_REGISTRATION_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS}
if endpoint.startswith(FB_INSTALL_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
if host == API_HOSTNAME:
access_token = await self.async_get_access_token()
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
kwargs["headers"]["Host"] = host

# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
Comment on lines +238 to +242
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the 'Ocp-Apim-Subscription-Key' header when the endpoint is 'AUTOMATION_CHALLENGE_ENDPOINT' might introduce issues if this header is required elsewhere in the system. It's crucial to document this change within the function's docstring or as comments explaining the broader context and its necessity.

Suggested change
# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
// Consider adding more comments or documentation here to explain why removing the header is necessary

Comment on lines +239 to +242
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling different authorization methods for various endpoints in the same function increases complexity and potential for errors. Consider extracting this logic into a separate function to improve readability and maintainability.

Suggested change
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
# Define a new function to handle request headers
def add_authorization_headers(endpoint, headers):
if endpoint.startswith(API_REGISTRATION_ENDPOINT):
headers.update(API_REGISTRATION_HEADERS)
elif endpoint.startswith(FB_INSTALL_ENDPOINT):
headers.update(FB_INSTALL_HEADERS)
elif endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
headers.update(ANDROID_CLIENT_HEADERS)
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
headers.pop("Ocp-Apim-Subscription-Key", None)

ic-dev21 marked this conversation as resolved.
Show resolved Hide resolved

data: dict[str, Any] = {}
url = parse.urljoin(f"https://{host}", endpoint)
if self.log_traces:
Expand Down Expand Up @@ -303,8 +311,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
LOG.info(
"401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}"
)
LOG.info(f"401 detected on {err.request_info.url}")
async with self._backoff_refresh_lock_ws:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.refresh_ws_token()
await self.get_websocket_params()
return

Expand Down Expand Up @@ -354,12 +363,23 @@ async def _async_post_init(self) -> None:
LOG.debug("Websocket postinit")
await self._get_fid()
await self._get_device_token()
await self.refresh_ws_token()
self.websocket = WebsocketClient(self)
# await self.refresh_ws_token()
# self.websocket = WebsocketClient(self)

# Initialize WebsocketManager ic-dev21
self.websocket_manager = WebsocketManager(
self.session, self.async_request, self._state_yaml, set_state
)
await self.websocket_manager.initialize_websockets()

# Create both websocket clients
self.websocket = WebsocketClient(self.websocket_manager.devicehub)
self.websocket2 = WebsocketClient(self.websocket_manager.challengehub)

async def refresh_ws_token(self) -> None:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.get_websocket_params()
"""Refresh the websocket token."""
await self.websocket_manager.refresh_token(self.websocket_manager.devicehub)
await self.websocket_manager.refresh_token(self.websocket_manager.challengehub)

async def post_devicehub_negociate(self) -> tuple[str, str]:
LOG.debug("Getting websocket url")
Expand Down
2 changes: 2 additions & 0 deletions pyhilo/const.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Automation server constant
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"
AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub"


# Request constants
DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}"
Expand Down
160 changes: 154 additions & 6 deletions pyhilo/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
from enum import IntEnum
import json
from os import environ
from typing import TYPE_CHECKING, Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
Comment on lines +10 to +11
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an unused import 'parse' which might have been intended for use in the module 'urllib'. Consider removing it to clean up unused imports and prevent confusion.

Suggested change
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple


from aiohttp import ClientWebSocketResponse, WSMsgType
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
from aiohttp.client_exceptions import (
ClientError,
ServerDisconnectedError,
WSServerHandshakeError,
)
from yarl import URL

from pyhilo.const import DEFAULT_USER_AGENT, LOG
from pyhilo.const import (
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_USER_AGENT,
LOG,
)
from pyhilo.exceptions import (
CannotConnectError,
ConnectionClosedError,
Expand Down Expand Up @@ -208,7 +214,7 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None:

if self._api.log_traces:
LOG.debug(
f"[TRACE] Sending data to websocket server: {json.dumps(payload)}"
f"[TRACE] Sending data to websocket {self._api.endpoint} : {json.dumps(payload)}"
)
# Hilo added a control character (chr(30)) at the end of each payload they send.
# They also expect this char to be there at the end of every payload we send them.
Expand Down Expand Up @@ -263,7 +269,7 @@ async def async_connect(self) -> None:

LOG.info("Websocket: Connecting to server")
if self._api.log_traces:
LOG.debug(f"[TRACE] Websocket URL: {self._api.full_ws_url}")
LOG.debug(f"[TRACE] Websocket URL: {self._api.full_url}")
headers = {
"Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits",
"Pragma": "no-cache",
Expand All @@ -281,7 +287,7 @@ async def async_connect(self) -> None:
try:
self._client = await self._api.session.ws_connect(
URL(
self._api.full_ws_url,
self._api.full_url,
encoded=True,
),
heartbeat=55,
Expand Down Expand Up @@ -376,3 +382,145 @@ async def async_invoke(
"type": inv_type,
}
)


@dataclass
class WebsocketConfig:
"""Configuration for a websocket connection"""

endpoint: str
url: Optional[str] = None
token: Optional[str] = None
connection_id: Optional[str] = None
full_url: Optional[str] = None
log_traces: bool = True
session: ClientSession | None = None


class WebsocketManager:
"""Manages multiple websocket connections for the Hilo API"""

def __init__(
self, session: ClientSession, async_request, state_yaml: str, set_state_callback
) -> None:
"""Initialize the websocket manager.

Args:
session: The aiohttp client session
async_request: The async request method from the API class
state_yaml: Path to the state file
set_state_callback: Callback to save state
"""
self.session = session
self.async_request = async_request
self._state_yaml = state_yaml
self._set_state = set_state_callback
self._shared_token = None # ic-dev21 need to share the token

# Initialize websocket configurations
self.devicehub = WebsocketConfig(
endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session
)
self.challengehub = WebsocketConfig(
endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session
)

async def initialize_websockets(self) -> None:
"""Initialize both websocket connections"""
# ic-dev21 get token from device hub
await self.refresh_token(self.devicehub, get_new_token=True)
# ic-dev21 reuse it for challenge hub
await self.refresh_token(self.challengehub, get_new_token=True)

async def refresh_token(
self, config: WebsocketConfig, get_new_token: bool = True
) -> None:
"""Refresh token for a specific websocket configuration.

Args:
config: The websocket configuration to refresh
"""
if get_new_token:
config.url, self._shared_token = await self._negotiate(config)
config.token = self._shared_token
else:
# ic-dev21 reuse existing token but get new URL
config.url, _ = await self._negotiate(config)
config.token = self._shared_token

await self._get_websocket_params(config)

async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]:
"""Negotiate websocket connection and get URL and token.

Args:
config: The websocket configuration to negotiate

Returns:
Tuple containing the websocket URL and access token
"""
LOG.debug(f"Getting websocket url for {config.endpoint}")
url = f"{config.endpoint}/negotiate"
LOG.debug(f"Negotiate URL is {url}")
Comment on lines +465 to +467
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String interpolation with f-strings for logging is generally fine for lower volume messages, but be mindful if this was a high-volume statement as it could have a performance impact. Consider lazy logging if performance becomes a concern.


resp = await self.async_request("post", url)
ws_url = resp.get("url")
ws_token = resp.get("accessToken")

# Save state
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for determining the state_key here is dependent on string comparison of endpoint names. This can become error-prone if endpoint names change. Consider using a more robust method of differentiating the websocket types, possibly by leveraging enumeration or a mapping approach.

Suggested change
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
endpoint_mapping = {
"AUTOMATION_DEVICEHUB_ENDPOINT": "websocket",
"AUTOMATION_CHALLENGE_ENDPOINT": "websocket2"
}
state_key = endpoint_mapping.get(config.endpoint, "unknown")

await self._set_state(
self._state_yaml,
state_key,
{
"url": ws_url,
"token": ws_token,
},
)

return ws_url, ws_token

async def _get_websocket_params(self, config: WebsocketConfig) -> None:
"""Get websocket parameters including connection ID.

Args:
config: The websocket configuration to get parameters for
"""
uri = parse.urlparse(config.url)
LOG.debug(f"Getting websocket params for {config.endpoint}")
LOG.debug(f"Getting uri {uri}")

resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of 'resp.get("connectionId")' without a default could lead to potential KeyError. Consider using a default value to safeguard against missing dictionary keys.

Suggested change
resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
config.connection_id = resp.get("connectionId", "")


config.connection_id = resp.get("connectionId", "")
config.full_url = (
f"{config.url}&id={config.connection_id}&access_token={config.token}"
)
LOG.debug(f"Getting full ws URL {config.full_url}")

transport_dict = resp.get("availableTransports", [])
websocket_dict = {
"connection_id": config.connection_id,
"available_transports": transport_dict,
"full_url": config.full_url,
}

# Save state
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
LOG.debug(f"Calling set_state {state_key}_params")
await self._set_state(self._state_yaml, state_key, websocket_dict)
Loading