From d8427643006902702e18f162e019aa6c798760e1 Mon Sep 17 00:00:00 2001 From: spacemanspiff2007 <10754716+spacemanspiff2007@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:52:18 +0100 Subject: [PATCH] . --- src/HABApp/core/connections/manager.py | 17 +++++++++++++---- src/HABApp/core/connections/plugin_callback.py | 12 ++++++++---- .../core/connections/status_transitions.py | 8 ++++---- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/HABApp/core/connections/manager.py b/src/HABApp/core/connections/manager.py index 2cb523da..f1fd3e1f 100644 --- a/src/HABApp/core/connections/manager.py +++ b/src/HABApp/core/connections/manager.py @@ -1,13 +1,17 @@ from __future__ import annotations import asyncio -from typing import Final, TypeVar +from typing import TYPE_CHECKING, Final, TypeVar import HABApp from HABApp.core.connections import BaseConnection from HABApp.core.connections._definitions import connection_log +if TYPE_CHECKING: + from collections.abc import Generator + + T = TypeVar('T', bound=BaseConnection) @@ -16,16 +20,21 @@ def __init__(self) -> None: self.connections: dict[str, BaseConnection] = {} def add(self, connection: T) -> T: - assert connection.name not in self.connections + if connection.name in self.connections: + msg = f'Connection {connection.name:s} already exists!' + raise ValueError(msg) + self.connections[connection.name] = connection connection_log.debug(f'Added {connection.name:s}') - return connection def get(self, name: str) -> BaseConnection: return self.connections[name] - def remove(self, name): + def get_names(self) -> Generator[str, None, None]: + yield from self.connections.keys() + + def remove(self, name: str) -> None: con = self.get(name) if not con.is_shutdown: raise ValueError() diff --git a/src/HABApp/core/connections/plugin_callback.py b/src/HABApp/core/connections/plugin_callback.py index a83e2f60..fac9eef3 100644 --- a/src/HABApp/core/connections/plugin_callback.py +++ b/src/HABApp/core/connections/plugin_callback.py @@ -6,6 +6,8 @@ from inspect import getmembers, iscoroutinefunction, signature from typing import TYPE_CHECKING, Any +from typing_extensions import Self + from ._definitions import ConnectionStatus @@ -54,9 +56,10 @@ async def run(self, connection: BaseConnection, context: Any): return await self.coro(**kwargs) @staticmethod - def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]): + def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]) -> tuple[str, ...]: if not iscoroutinefunction(coro): - raise ValueError(f'Coroutine function expected for {plugin.plugin_name}.{coro.__name__}') + msg = f'Coroutine function expected for {plugin.plugin_name}.{coro.__name__}' + raise ValueError(msg) sig = signature(coro) @@ -65,9 +68,10 @@ def _get_coro_kwargs(plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitab if name in ('connection', 'context'): kwargs.append(name) else: - raise ValueError(f'Invalid parameter name "{name:s}" for {plugin.plugin_name}.{coro.__name__}') + msg = f'Invalid parameter name "{name:s}" for {plugin.plugin_name}.{coro.__name__}' + raise ValueError(msg) return tuple(kwargs) @classmethod - def create(cls, plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]): + def create(cls, plugin: BaseConnectionPlugin, coro: Callable[[...], Awaitable]) -> Self: return cls(plugin, coro, cls._get_coro_kwargs(plugin, coro)) diff --git a/src/HABApp/core/connections/status_transitions.py b/src/HABApp/core/connections/status_transitions.py index 7d1848ac..da0df692 100644 --- a/src/HABApp/core/connections/status_transitions.py +++ b/src/HABApp/core/connections/status_transitions.py @@ -77,11 +77,11 @@ def _next_step(self) -> ConnectionStatus: return transitions.get(status) def __repr__(self) -> str: - return f'<{self.__class__.__name__} {self.status} ' \ - f'[{"x" if self.error else " "}] Error, ' \ - f'[{"x" if self.setup else " "}] Setup>' + return (f'<{self.__class__.__name__} {self.status} ' + f'[{"x" if self.error else " "}] Error, ' + f'[{"x" if self.setup else " "}] Setup>') - def __eq__(self, other: ConnectionStatus): + def __eq__(self, other: ConnectionStatus) -> bool: if not isinstance(other, ConnectionStatus): return NotImplemented return self.status == other