Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Premier essai (NE PAS MERGER!!!) #224

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions pyhilo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
API_NOTIFICATIONS_ENDPOINT,
API_REGISTRATION_ENDPOINT,
API_REGISTRATION_HEADERS,
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_STATE_FILE,
DEFAULT_USER_AGENT,
Expand All @@ -51,7 +52,7 @@
get_state,
set_state,
)
from pyhilo.websocket import WebsocketClient
from pyhilo.websocket import WebsocketClient, WebsocketManager


class API:
Expand Down Expand Up @@ -216,17 +217,24 @@ async def _async_request(
:rtype: dict[str, Any]
"""
kwargs.setdefault("headers", self.headers)
access_token = await self.async_get_access_token()

if endpoint.startswith(API_REGISTRATION_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS}
if endpoint.startswith(FB_INSTALL_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
if host == API_HOSTNAME:
access_token = await self.async_get_access_token()
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
kwargs["headers"]["Host"] = host

# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
Comment on lines +235 to +239
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the 'Ocp-Apim-Subscription-Key' header when the endpoint is 'AUTOMATION_CHALLENGE_ENDPOINT' might introduce issues if this header is required elsewhere in the system. It's crucial to document this change within the function's docstring or as comments explaining the broader context and its necessity.

Suggested change
# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
// Consider adding more comments or documentation here to explain why removing the header is necessary

Comment on lines +236 to +239
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling different authorization methods for various endpoints in the same function increases complexity and potential for errors. Consider extracting this logic into a separate function to improve readability and maintainability.

Suggested change
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
# Define a new function to handle request headers
def add_authorization_headers(endpoint, headers):
if endpoint.startswith(API_REGISTRATION_ENDPOINT):
headers.update(API_REGISTRATION_HEADERS)
elif endpoint.startswith(FB_INSTALL_ENDPOINT):
headers.update(FB_INSTALL_HEADERS)
elif endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
headers.update(ANDROID_CLIENT_HEADERS)
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
headers.pop("Ocp-Apim-Subscription-Key", None)

ic-dev21 marked this conversation as resolved.
Show resolved Hide resolved

data: dict[str, Any] = {}
url = parse.urljoin(f"https://{host}", endpoint)
if self.log_traces:
Expand Down Expand Up @@ -303,8 +311,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
LOG.info(
"401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}"
)
LOG.info(f"401 detected on {err.request_info.url}")
async with self._backoff_refresh_lock_ws:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.refresh_ws_token()
await self.get_websocket_params()
return

Expand Down Expand Up @@ -354,12 +363,23 @@ async def _async_post_init(self) -> None:
LOG.debug("Websocket postinit")
await self._get_fid()
await self._get_device_token()
await self.refresh_ws_token()
self.websocket = WebsocketClient(self)
# await self.refresh_ws_token()
# self.websocket = WebsocketClient(self)

# Initialize WebsocketManager ic-dev21
self.websocket_manager = WebsocketManager(
self.session, self.async_request, self._state_yaml, set_state
)
await self.websocket_manager.initialize_websockets()

# Create both websocket clients
self.websocket = WebsocketClient(self.websocket_manager.devicehub)
self.websocket2 = WebsocketClient(self.websocket_manager.challengehub)

async def refresh_ws_token(self) -> None:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.get_websocket_params()
"""Refresh the websocket token."""
await self.websocket_manager.refresh_token(self.websocket_manager.devicehub)
await self.websocket_manager.refresh_token(self.websocket_manager.challengehub)

async def post_devicehub_negociate(self) -> tuple[str, str]:
LOG.debug("Getting websocket url")
Expand Down
2 changes: 2 additions & 0 deletions pyhilo/const.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Automation server constant
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"
AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub"


# Request constants
DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}"
Expand Down
160 changes: 154 additions & 6 deletions pyhilo/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
from enum import IntEnum
import json
from os import environ
from typing import TYPE_CHECKING, Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
Comment on lines +10 to +11
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an unused import 'parse' which might have been intended for use in the module 'urllib'. Consider removing it to clean up unused imports and prevent confusion.

Suggested change
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple


from aiohttp import ClientWebSocketResponse, WSMsgType
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
from aiohttp.client_exceptions import (
ClientError,
ServerDisconnectedError,
WSServerHandshakeError,
)
from yarl import URL

from pyhilo.const import DEFAULT_USER_AGENT, LOG
from pyhilo.const import (
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_USER_AGENT,
LOG,
)
from pyhilo.exceptions import (
CannotConnectError,
ConnectionClosedError,
Expand Down Expand Up @@ -208,7 +214,7 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None:

if self._api.log_traces:
LOG.debug(
f"[TRACE] Sending data to websocket server: {json.dumps(payload)}"
f"[TRACE] Sending data to websocket {self._api.endpoint} : {json.dumps(payload)}"
)
# Hilo added a control character (chr(30)) at the end of each payload they send.
# They also expect this char to be there at the end of every payload we send them.
Expand Down Expand Up @@ -263,7 +269,7 @@ async def async_connect(self) -> None:

LOG.info("Websocket: Connecting to server")
if self._api.log_traces:
LOG.debug(f"[TRACE] Websocket URL: {self._api.full_ws_url}")
LOG.debug(f"[TRACE] Websocket URL: {self._api.full_url}")
headers = {
"Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits",
"Pragma": "no-cache",
Expand All @@ -281,7 +287,7 @@ async def async_connect(self) -> None:
try:
self._client = await self._api.session.ws_connect(
URL(
self._api.full_ws_url,
self._api.full_url,
encoded=True,
),
heartbeat=55,
Expand Down Expand Up @@ -376,3 +382,145 @@ async def async_invoke(
"type": inv_type,
}
)


@dataclass
class WebsocketConfig:
"""Configuration for a websocket connection"""

endpoint: str
url: Optional[str] = None
token: Optional[str] = None
connection_id: Optional[str] = None
full_url: Optional[str] = None
log_traces: bool = True
session: ClientSession | None = None


class WebsocketManager:
"""Manages multiple websocket connections for the Hilo API"""

def __init__(
self, session: ClientSession, async_request, state_yaml: str, set_state_callback
) -> None:
"""Initialize the websocket manager.

Args:
session: The aiohttp client session
async_request: The async request method from the API class
state_yaml: Path to the state file
set_state_callback: Callback to save state
"""
self.session = session
self.async_request = async_request
self._state_yaml = state_yaml
self._set_state = set_state_callback
self._shared_token = None # ic-dev21 need to share the token

# Initialize websocket configurations
self.devicehub = WebsocketConfig(
endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session
)
self.challengehub = WebsocketConfig(
endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session
)

async def initialize_websockets(self) -> None:
"""Initialize both websocket connections"""
# ic-dev21 get token from device hub
await self.refresh_token(self.devicehub, get_new_token=True)
# ic-dev21 reuse it for challenge hub
await self.refresh_token(self.challengehub, get_new_token=True)

async def refresh_token(
self, config: WebsocketConfig, get_new_token: bool = True
) -> None:
"""Refresh token for a specific websocket configuration.

Args:
config: The websocket configuration to refresh
"""
if get_new_token:
config.url, self._shared_token = await self._negotiate(config)
config.token = self._shared_token
else:
# ic-dev21 reuse existing token but get new URL
config.url, _ = await self._negotiate(config)
config.token = self._shared_token

await self._get_websocket_params(config)

async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]:
"""Negotiate websocket connection and get URL and token.

Args:
config: The websocket configuration to negotiate

Returns:
Tuple containing the websocket URL and access token
"""
LOG.debug(f"Getting websocket url for {config.endpoint}")
url = f"{config.endpoint}/negotiate"
LOG.debug(f"Negotiate URL is {url}")
Comment on lines +463 to +465
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String interpolation with f-strings for logging is generally fine for lower volume messages, but be mindful if this was a high-volume statement as it could have a performance impact. Consider lazy logging if performance becomes a concern.


resp = await self.async_request("post", url)
ws_url = resp.get("url")
ws_token = resp.get("accessToken")

# Save state
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for determining the state_key here is dependent on string comparison of endpoint names. This can become error-prone if endpoint names change. Consider using a more robust method of differentiating the websocket types, possibly by leveraging enumeration or a mapping approach.

Suggested change
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
endpoint_mapping = {
"AUTOMATION_DEVICEHUB_ENDPOINT": "websocket",
"AUTOMATION_CHALLENGE_ENDPOINT": "websocket2"
}
state_key = endpoint_mapping.get(config.endpoint, "unknown")

await self._set_state(
self._state_yaml,
state_key,
{
"url": ws_url,
"token": ws_token,
},
)

return ws_url, ws_token

async def _get_websocket_params(self, config: WebsocketConfig) -> None:
"""Get websocket parameters including connection ID.

Args:
config: The websocket configuration to get parameters for
"""
uri = parse.urlparse(config.url)
LOG.debug(f"Getting websocket params for {config.endpoint}")
LOG.debug(f"Getting uri {uri}")

resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of 'resp.get("connectionId")' without a default could lead to potential KeyError. Consider using a default value to safeguard against missing dictionary keys.

Suggested change
resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
config.connection_id = resp.get("connectionId", "")


config.connection_id = resp.get("connectionId", "")
config.full_url = (
f"{config.url}&id={config.connection_id}&access_token={config.token}"
)
LOG.debug(f"Getting full ws URL {config.full_url}")

transport_dict = resp.get("availableTransports", [])
websocket_dict = {
"connection_id": config.connection_id,
"available_transports": transport_dict,
"full_url": config.full_url,
}

# Save state
state_key = (
"websocket"
if config.endpoint == "AUTOMATION_DEVICEHUB_ENDPOINT"
else "websocket2"
)
LOG.debug(f"Calling set_state {state_key}_params")
await self._set_state(self._state_yaml, state_key, websocket_dict)
Loading