|
3 | 3 |
|
4 | 4 | import hashlib |
5 | 5 | import json |
| 6 | +from pathlib import Path |
6 | 7 | from typing import Any |
7 | 8 |
|
8 | 9 | import voluptuous as vol |
|
12 | 13 |
|
13 | 14 | from . import DOMAIN |
14 | 15 | from .ssh_discovery import discover_ssh_hosts, guess_local_network |
| 16 | +from .util import resolve_private_key_path |
15 | 17 |
|
16 | 18 | DEFAULT_INTERVAL = 30 |
17 | 19 |
|
@@ -95,35 +97,44 @@ async def async_step_user(self, user_input: dict[str, Any] | None = None): |
95 | 97 | } |
96 | 98 | if user_input.get("password"): |
97 | 99 | server["password"] = user_input["password"] |
98 | | - if user_input.get("key"): |
99 | | - server["key"] = user_input["key"] |
100 | | - self._servers.append(server) |
101 | | - if user_input.get("add_another"): |
102 | | - hosts = await self._get_discovered_hosts() |
103 | | - return self.async_show_form( |
104 | | - step_id="user", |
105 | | - data_schema=_build_server_schema( |
106 | | - hosts, |
107 | | - include_interval=False, |
108 | | - interval_default=self._interval, |
109 | | - default_name=vol.UNDEFINED, |
110 | | - ), |
| 100 | + key_input = user_input.get("key") |
| 101 | + if key_input: |
| 102 | + resolved = resolve_private_key_path(self.hass, key_input) |
| 103 | + if not Path(resolved).exists(): |
| 104 | + errors["key"] = "key_missing" |
| 105 | + defaults = user_input |
| 106 | + else: |
| 107 | + server["key"] = resolved |
| 108 | + if errors: |
| 109 | + defaults = user_input |
| 110 | + else: |
| 111 | + self._servers.append(server) |
| 112 | + if user_input.get("add_another"): |
| 113 | + hosts = await self._get_discovered_hosts() |
| 114 | + return self.async_show_form( |
| 115 | + step_id="user", |
| 116 | + data_schema=_build_server_schema( |
| 117 | + hosts, |
| 118 | + include_interval=False, |
| 119 | + interval_default=self._interval, |
| 120 | + default_name=vol.UNDEFINED, |
| 121 | + ), |
| 122 | + ) |
| 123 | + |
| 124 | + hosts_for_id = ",".join(sorted(server["host"] for server in self._servers)) |
| 125 | + unique_id = hashlib.sha256(hosts_for_id.encode()).hexdigest() |
| 126 | + await self.async_set_unique_id(unique_id) |
| 127 | + self._abort_if_unique_id_configured() |
| 128 | + data = { |
| 129 | + "interval": self._interval, |
| 130 | + "servers_json": json.dumps(self._servers), |
| 131 | + } |
| 132 | + title = ( |
| 133 | + self._servers[0]["name"] |
| 134 | + if len(self._servers) == 1 |
| 135 | + else "VServer SSH Stats" |
111 | 136 | ) |
112 | | - |
113 | | - hosts_for_id = ",".join(sorted(server["host"] for server in self._servers)) |
114 | | - unique_id = hashlib.sha256(hosts_for_id.encode()).hexdigest() |
115 | | - await self.async_set_unique_id(unique_id) |
116 | | - self._abort_if_unique_id_configured() |
117 | | - data = { |
118 | | - "interval": self._interval, |
119 | | - "servers_json": json.dumps(self._servers), |
120 | | - } |
121 | | - title = ( |
122 | | - self._servers[0]["name"] |
123 | | - if len(self._servers) == 1 |
124 | | - else "VServer SSH Stats" |
125 | | - ) |
126 | | - return self.async_create_entry(title=title, data=data) |
| 137 | + return self.async_create_entry(title=title, data=data) |
127 | 138 |
|
128 | 139 | hosts = await self._get_discovered_hosts() |
129 | 140 | default_name = self._discovered_name if first_server else vol.UNDEFINED |
@@ -213,6 +224,11 @@ def __init__(self, config_entry: config_entries.ConfigEntry) -> None: |
213 | 224 | ) |
214 | 225 | except ValueError: |
215 | 226 | self._existing_servers = [] |
| 227 | + hass = config_entry.hass |
| 228 | + for server in self._existing_servers: |
| 229 | + key = resolve_private_key_path(hass, server.get("key")) if hass else server.get("key") |
| 230 | + if key: |
| 231 | + server["key"] = key |
216 | 232 | self._pending_servers: list[dict[str, Any]] = [] |
217 | 233 |
|
218 | 234 | async def async_step_init(self, user_input: dict[str, Any] | None = None): |
@@ -271,23 +287,32 @@ async def async_step_servers(self, user_input: dict[str, Any] | None = None): |
271 | 287 | } |
272 | 288 | if user_input.get("password"): |
273 | 289 | server["password"] = user_input["password"] |
274 | | - if user_input.get("key"): |
275 | | - server["key"] = user_input["key"] |
276 | | - self._pending_servers.append(server) |
277 | | - if user_input.get("add_another"): |
278 | | - hosts = await self._get_discovered_hosts() |
279 | | - return self.async_show_form( |
280 | | - step_id="servers", |
281 | | - data_schema=_build_server_schema( |
282 | | - hosts, |
283 | | - include_interval=False, |
284 | | - interval_default=self._interval, |
285 | | - default_name=vol.UNDEFINED, |
286 | | - ), |
287 | | - ) |
288 | | - |
289 | | - self._update_entry(self._pending_servers) |
290 | | - return self.async_create_entry(title="", data={}) |
| 290 | + key_input = user_input.get("key") |
| 291 | + if key_input: |
| 292 | + resolved = resolve_private_key_path(self.hass, key_input) |
| 293 | + if not Path(resolved).exists(): |
| 294 | + errors["key"] = "key_missing" |
| 295 | + defaults = user_input |
| 296 | + else: |
| 297 | + server["key"] = resolved |
| 298 | + if errors: |
| 299 | + defaults = user_input |
| 300 | + else: |
| 301 | + self._pending_servers.append(server) |
| 302 | + if user_input.get("add_another"): |
| 303 | + hosts = await self._get_discovered_hosts() |
| 304 | + return self.async_show_form( |
| 305 | + step_id="servers", |
| 306 | + data_schema=_build_server_schema( |
| 307 | + hosts, |
| 308 | + include_interval=False, |
| 309 | + interval_default=self._interval, |
| 310 | + default_name=vol.UNDEFINED, |
| 311 | + ), |
| 312 | + ) |
| 313 | + |
| 314 | + self._update_entry(self._pending_servers) |
| 315 | + return self.async_create_entry(title="", data={}) |
291 | 316 |
|
292 | 317 | hosts = await self._get_discovered_hosts() |
293 | 318 | return self.async_show_form( |
|
0 commit comments