diff --git a/custom_components/entity_tz/__init__.py b/custom_components/entity_tz/__init__.py index 99e8d4d..86d3221 100644 --- a/custom_components/entity_tz/__init__.py +++ b/custom_components/entity_tz/__init__.py @@ -20,7 +20,7 @@ from homeassistant.util import dt as dt_util from .const import DOMAIN -from .helpers import etz_data, get_location, get_tz, init_etz_data, signal +from .helpers import etz_data, get_loc, get_tz, init_etz_data, signal CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) PLATFORMS = [Platform.BINARY_SENSOR, Platform.SENSOR] @@ -34,8 +34,9 @@ async def async_setup(hass: HomeAssistant, _: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up config entry.""" - etz_data(hass).loc_users[entry.entry_id] = 0 - etz_data(hass).tz_users[entry.entry_id] = 0 + etzd = etz_data(hass) + etzd.loc_users[entry.entry_id] = 0 + etzd.tz_users[entry.entry_id] = 0 await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) @@ -63,11 +64,11 @@ async def update_from_entity(event: Event | None = None) -> None: or new_state.attributes.get(ATTR_LONGITUDE) != old_state.attributes.get(ATTR_LONGITUDE) ): - if etz_data(hass).loc_users[entry.entry_id]: - entity_loc = await get_location(hass, new_state) + if etzd.loc_available and etzd.loc_users[entry.entry_id]: + entity_loc = await get_loc(hass, new_state) else: entity_loc = None - if etz_data(hass).tz_users[entry.entry_id]: + if etzd.tz_users[entry.entry_id]: entity_tz = get_tz(hass, new_state) else: entity_tz = None @@ -93,8 +94,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" res = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - loc_users = etz_data(hass).loc_users.pop(entry.entry_id) - tz_uers = etz_data(hass).tz_users.pop(entry.entry_id) + etzd = etz_data(hass) + loc_users = etzd.loc_users.pop(entry.entry_id) + tz_uers = etzd.tz_users.pop(entry.entry_id) assert not loc_users assert not tz_uers return res diff --git a/custom_components/entity_tz/const.py b/custom_components/entity_tz/const.py index 27049c9..32c0715 100644 --- a/custom_components/entity_tz/const.py +++ b/custom_components/entity_tz/const.py @@ -10,9 +10,6 @@ LOCAL_TIME_ICON = "mdi:account-clock" TIME_ZONE_ICON = "mdi:map-clock" -NOM_TIMEOUT = 10 -NOM_WAIT = 1.5 - ATTR_COUNTRY_CODE = "country_code" ATTR_UTC_OFFSET = "utc_offset" diff --git a/custom_components/entity_tz/helpers.py b/custom_components/entity_tz/helpers.py index 9d25922..3e4c55f 100644 --- a/custom_components/entity_tz/helpers.py +++ b/custom_components/entity_tz/helpers.py @@ -1,17 +1,13 @@ """Entity Time Zone Sensor Helpers.""" from __future__ import annotations -import asyncio from collections.abc import Container, Mapping from dataclasses import dataclass from datetime import tzinfo from enum import Enum, auto -import logging from typing import Any, cast from zoneinfo import available_timezones -from geopy.adapters import AioHTTPAdapter -from geopy.geocoders import Nominatim from geopy.location import Location from timezonefinder import TimezoneFinder @@ -31,10 +27,6 @@ STATE_UNAVAILABLE, ) from homeassistant.core import Event, HomeAssistant, State, callback, split_entity_id -from homeassistant.helpers.aiohttp_client import ( - SERVER_SOFTWARE, - async_get_clientsession, -) from homeassistant.helpers.device_registry import DeviceEntryType # Device Info moved to device_registry in 2023.9 @@ -48,20 +40,20 @@ from homeassistant.helpers.event import async_track_time_change from homeassistant.util import dt as dt_util -from .const import DOMAIN, NOM_TIMEOUT, NOM_WAIT, SIG_ENTITY_CHANGED +from .const import DOMAIN, SIG_ENTITY_CHANGED +from .nominatim import get_location, init_nominatim -_LOGGER = logging.getLogger(__name__) +_ALWAYS_DISABLED_ENTITIES = ("address", "country", "diff_country", "diff_time") @dataclass(init=False) class ETZData: """Entity Time Zone integration data.""" - tzs: set[str] - nominatim: Nominatim + loc_available: bool loc_users: dict[str, int] - query_lock: asyncio.Lock tzf: TimezoneFinder + tzs: set[str] tz_users: dict[str, int] zones: list[str] @@ -75,29 +67,19 @@ async def init_etz_data(hass: HomeAssistant) -> None: """Initialize integration's data.""" if DOMAIN in hass.data: return - hass.data[DOMAIN] = ETZData() - etz_data(hass).loc_users = {} - etz_data(hass).tz_users = {} - - nominatim = Nominatim( - user_agent=SERVER_SOFTWARE, - timeout=NOM_TIMEOUT, - adapter_factory=lambda proxies, ssl_context: AioHTTPAdapter( - proxies=proxies, ssl_context=ssl_context - ), - ) - nominatim.adapter.__dict__["session"] = async_get_clientsession(hass) - etz_data(hass).nominatim = nominatim - etz_data(hass).query_lock = asyncio.Lock() + hass.data[DOMAIN] = etzd = ETZData() + etzd.loc_available = init_nominatim(hass) + etzd.loc_users = {} + etzd.tz_users = {} def init_tz_data() -> None: - """Initialize time zone data.""" - - # This must be done in an executor since zoneinfo.available_timezones and the - # TimezoneFinder constructor both do file I/O. + """Initialize time zone data. - etz_data(hass).tzs = available_timezones() - etz_data(hass).tzf = TimezoneFinder() + This must be done in an executor since TimezoneFinder constructor and + zoneinfo.available_timezones both do file I/O. + """ + etzd.tzf = TimezoneFinder() + etzd.tzs = available_timezones() await hass.async_add_executor_job(init_tz_data) @@ -108,7 +90,7 @@ def update_zones(_: Event | None = None) -> None: for state in hass.states.async_all(ZONE_DOMAIN): if get_tz(hass, state) != dt_util.DEFAULT_TIME_ZONE: zones.append(state.entity_id) - etz_data(hass).zones = zones + etzd.zones = zones @callback def zones_filter(event: Event) -> bool: @@ -121,7 +103,7 @@ def zones_filter(event: Event) -> bool: def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None: - """Get time zone from latitude & longitude from state.""" + """Get time zone from entity state.""" if not state: return STATE_UNAVAILABLE if state.domain in (PERSON_DOMAIN, DT_DOMAIN) and state.state == STATE_HOME: @@ -136,10 +118,8 @@ def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None: return dt_util.get_time_zone(tz_name) -async def get_location( - hass: HomeAssistant, state: State | None -) -> Location | str | None: - """Get address from latitude & longitude.""" +async def get_loc(hass: HomeAssistant, state: State | None) -> Location | str | None: + """Get location data from entity state.""" if state is None: return STATE_UNAVAILABLE lat = state.attributes.get(ATTR_LATITUDE) @@ -147,22 +127,7 @@ async def get_location( if lat is None or lng is None: return STATE_UNAVAILABLE - lock = etz_data(hass).query_lock - - async def limit_rate() -> None: - """Hold the lock to limit calls to server.""" - await asyncio.sleep(NOM_WAIT) - lock.release() - - coordinates = f"{lat}, {lng}" - await lock.acquire() - try: - return await etz_data(hass).nominatim.reverse(coordinates) - except Exception as exc: # pylint: disable=broad-exception-caught - _LOGGER.error("While getting address & country code: %s", exc) - return None - finally: - hass.async_create_background_task(limit_rate(), "Limit nominatim query rate") + return await get_location(hass, lat, lng) def signal(entry: ConfigEntry) -> str: @@ -170,9 +135,6 @@ def signal(entry: ConfigEntry) -> str: return f"{SIG_ENTITY_CHANGED}-{entry.entry_id}" -_ALWAYS_DISABLED_ENTITIES = ("address", "country", "diff_country", "diff_time") - - def _enable_entity(key: str, entry_data: Mapping[str, Any]) -> bool: """Determine if entity should be enabled by default.""" if key in _ALWAYS_DISABLED_ENTITIES: @@ -252,10 +214,11 @@ def entity_changed( loc_user = ETZSource.LOC in self._sources tz_user = ETZSource.TZ in self._sources if loc_user or tz_user: + etzd = etz_data(self.hass) if loc_user: - etz_data(self.hass).loc_users[config_entry.entry_id] += 1 + etzd.loc_users[config_entry.entry_id] += 1 if tz_user: - etz_data(self.hass).tz_users[config_entry.entry_id] += 1 + etzd.tz_users[config_entry.entry_id] += 1 self.async_on_remove( async_dispatcher_connect( self.hass, signal(config_entry), entity_changed @@ -278,10 +241,11 @@ def update(_: Any) -> None: async def async_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" config_entry = cast(ConfigEntry, self.platform.config_entry) + etzd = etz_data(self.hass) if ETZSource.LOC in self._sources: - etz_data(self.hass).loc_users[config_entry.entry_id] -= 1 + etzd.loc_users[config_entry.entry_id] -= 1 if ETZSource.TZ in self._sources: - etz_data(self.hass).tz_users[config_entry.entry_id] -= 1 + etzd.tz_users[config_entry.entry_id] -= 1 @property def _sources_valid(self) -> bool: diff --git a/custom_components/entity_tz/manifest.json b/custom_components/entity_tz/manifest.json index b6a71d1..066122d 100644 --- a/custom_components/entity_tz/manifest.json +++ b/custom_components/entity_tz/manifest.json @@ -7,6 +7,7 @@ "documentation": "https://github.com/pnbruckner/ha-entity-tz/blob/master/README.md", "iot_class": "calculated", "issue_tracker": "https://github.com/pnbruckner/ha-entity-tz/issues", + "loggers": ["geopy"], "requirements": ["timezonefinder==5.2.0"], "version": "1.0.0b9" } diff --git a/custom_components/entity_tz/nominatim.py b/custom_components/entity_tz/nominatim.py new file mode 100644 index 0000000..27d2b37 --- /dev/null +++ b/custom_components/entity_tz/nominatim.py @@ -0,0 +1,96 @@ +"""Nominatim from geopy helper.""" + +import asyncio +from dataclasses import dataclass, fields +import logging + +from geopy.adapters import AioHTTPAdapter +from geopy.exc import GeocoderRateLimited +from geopy.geocoders import Nominatim +from geopy.location import Location + +from homeassistant.core import HomeAssistant +from homeassistant.helpers.aiohttp_client import ( + SERVER_SOFTWARE, + async_get_clientsession, +) + +NOMINATIM_DATA = "sharable_nominatim" +TIMEOUT = 10 +INITIAL_WAIT = 1.5 +CACHE_SIZE = 25 + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class NomData: + """Nomination data.""" + + cache: dict[str, Location | None] + lock: asyncio.Lock + nominatim: Nominatim + wait: float + + +def _save_nom_data(hass: HomeAssistant, nom_data: NomData) -> None: + """Save nominatim data to hass.data.""" + hass.data[NOMINATIM_DATA] = { + field.name: getattr(nom_data, field.name) for field in fields(nom_data) + } + + +def init_nominatim(hass: HomeAssistant) -> bool: + """Initialize sharable Nominatim object.""" + if NOMINATIM_DATA in hass.data: + nom_data = hass.data[NOMINATIM_DATA] + try: + for field in fields(NomData): + if not isinstance(nom_data[field.name], field.type): + raise TypeError + except (KeyError, TypeError): + msg = f"Unexpected data in hass.data[{NOMINATIM_DATA!r}]: {nom_data}" + _LOGGER.error(msg) + return False + return True + + nominatim = Nominatim( + user_agent=SERVER_SOFTWARE, + timeout=TIMEOUT, + adapter_factory=lambda proxies, ssl_context: AioHTTPAdapter( + proxies=proxies, ssl_context=ssl_context + ), + ) + nominatim.adapter.__dict__["session"] = async_get_clientsession(hass) + + _save_nom_data(hass, NomData({}, asyncio.Lock(), nominatim, INITIAL_WAIT)) + _LOGGER.debug("Initialized Nominatim data with cache size = %i", CACHE_SIZE) + return True + + +async def get_location(hass: HomeAssistant, lat: float, lng: float) -> Location | None: + """Get location data from given coordinates.""" + nom_data = NomData(**hass.data[NOMINATIM_DATA]) + + async def limit_rate() -> None: + """Hold the lock to limit calls to server.""" + await asyncio.sleep(nom_data.wait) + nom_data.lock.release() + + coordinates = f"{lat}, {lng}" + await nom_data.lock.acquire() + try: + return await nom_data.nominatim.reverse(coordinates) + except GeocoderRateLimited as exc: + if retry_after := exc.retry_after: + if retry_after > nom_data.wait: + _LOGGER.debug("Increasing wait time to %f sec", retry_after) + nom_data.wait = retry_after + _save_nom_data(hass, nom_data) + _LOGGER.warning("Request has been rate limited. Will retry") + return await get_location(hass, lat, lng) + except Exception as exc: # pylint: disable=broad-exception-caught + _LOGGER.error("While retrieving reverse geolocation data: %s", exc) + return None + finally: + hass.async_create_background_task(limit_rate(), "Limit nominatim query rate")