Skip to content

Commit

Permalink
Implement Nominatim location caching
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbruckner committed Dec 20, 2023
1 parent d266c44 commit b5116c9
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 49 deletions.
33 changes: 31 additions & 2 deletions custom_components/entity_tz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Entity Time Zone Sensor."""
from __future__ import annotations

import logging

from async_lru import alru_cache
from geopy.location import Location

from homeassistant.components.device_tracker import DOMAIN as DT_DOMAIN
from homeassistant.components.person import DOMAIN as PERSON_DOMAIN
from homeassistant.config_entries import ConfigEntry
Expand All @@ -10,6 +15,7 @@
CONF_ENTITY_ID,
CONF_TIME_ZONE,
STATE_HOME,
STATE_UNAVAILABLE,
Platform,
)
from homeassistant.core import Event, HomeAssistant, State, callback
Expand All @@ -19,11 +25,15 @@
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util

from .const import DOMAIN
from .helpers import etz_data, get_loc, get_tz, init_etz_data, signal
from .const import DOMAIN, LOC_CACHE_PER_CONFIG
from .helpers import etz_data, get_tz, init_etz_data, signal
from .nominatim import get_location

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = [Platform.BINARY_SENSOR, Platform.SENSOR]
_LOGGER = logging.getLogger(__name__)

_get_location = alru_cache(0)(get_location)


async def async_setup(hass: HomeAssistant, _: ConfigType) -> bool:
Expand All @@ -32,12 +42,31 @@ async def async_setup(hass: HomeAssistant, _: ConfigType) -> bool:
return True


async def get_loc(hass: HomeAssistant, state: State | None) -> Location | str | None:
"""Get location data from entity state."""
if state is None:
return STATE_UNAVAILABLE
lat = state.attributes.get(ATTR_LATITUDE)
lng = state.attributes.get(ATTR_LONGITUDE)
if lat is None or lng is None:
return STATE_UNAVAILABLE

location = await _get_location(hass, f"{lat:.4f}, {lng:.4f}")
_LOGGER.debug("Location cache: %s", _get_location.cache_info())
return location


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up config entry."""
etzd = etz_data(hass)
etzd.loc_users[entry.entry_id] = 0
etzd.tz_users[entry.entry_id] = 0

loc_cache_size = len(etzd.loc_users) * LOC_CACHE_PER_CONFIG
_get_location._LRUCacheWrapper__maxsize = max( # type: ignore[attr-defined] # pylint: disable=protected-access
_get_location._LRUCacheWrapper__maxsize, loc_cache_size # type: ignore[attr-defined] # pylint: disable=protected-access
)

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

if (tz_name := entry.data.get(CONF_TIME_ZONE)) is not None:
Expand Down
8 changes: 2 additions & 6 deletions custom_components/entity_tz/binary_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
from homeassistant.const import CONF_TIME_ZONE
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import dt as dt_util

from .const import (
DIFF_COUNTRY_OFF_ICON,
DIFF_COUNTRY_ON_ICON,
DIFF_TIME_OFF_ICON,
DIFF_TIME_ON_ICON,
)
from .helpers import ETZEntity, ETZSource
from .helpers import ETZEntity, ETZSource, not_ha_tz


async def async_setup_entry(
Expand Down Expand Up @@ -67,10 +66,7 @@ async def async_update(self) -> None:
if not self._sources_valid:
return

n = dt_util.now()
self._attr_is_on = n.astimezone(self._entity_tz).replace(
tzinfo=None
) != n.replace(tzinfo=None)
self._attr_is_on = not_ha_tz(self._entity_tz)
if self.is_on:
self._attr_icon = DIFF_TIME_ON_ICON

Expand Down
2 changes: 2 additions & 0 deletions custom_components/entity_tz/const.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Constants for Entity Time Zone integration."""
DOMAIN = "entity_tz"

LOC_CACHE_PER_CONFIG = 8

ADDRESS_ICON = "mdi:map-marker"
COUNTRY_ICON = "mdi:web"
DIFF_COUNTRY_OFF_ICON = "mdi:home-city"
Expand Down
56 changes: 26 additions & 30 deletions custom_components/entity_tz/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from homeassistant.util import dt as dt_util

from .const import DOMAIN, SIG_ENTITY_CHANGED
from .nominatim import get_location, init_nominatim
from .nominatim import init_nominatim

_ALWAYS_DISABLED_ENTITIES = ("address", "country", "diff_country", "diff_time")

Expand All @@ -58,11 +58,35 @@ class ETZData:
zones: list[str]


def not_ha_tz(tz: tzinfo | str | None) -> bool:
"""Return if time zone is effectively different than HA's time zone."""
if not isinstance(tz, tzinfo):
return False
n = dt_util.now()
return n.astimezone(tz).replace(tzinfo=None) != n.replace(tzinfo=None)


def etz_data(hass: HomeAssistant) -> ETZData:
"""Return Entity Time Zone integration data."""
return cast(ETZData, hass.data[DOMAIN])


def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None:
"""Get time zone from entity state."""
if not state:
return STATE_UNAVAILABLE
if state.domain in (PERSON_DOMAIN, DT_DOMAIN) and state.state == STATE_HOME:
return dt_util.DEFAULT_TIME_ZONE
lat = state.attributes.get(ATTR_LATITUDE)
lng = state.attributes.get(ATTR_LONGITUDE)
if lat is None or lng is None:
return STATE_UNAVAILABLE
tz_name = etz_data(hass).tzf.timezone_at(lat=lat, lng=lng)
if tz_name is None:
return None
return dt_util.get_time_zone(tz_name)


async def init_etz_data(hass: HomeAssistant) -> None:
"""Initialize integration's data."""
if DOMAIN in hass.data:
Expand All @@ -88,7 +112,7 @@ def update_zones(_: Event | None = None) -> None:
"""Update list of zones to use."""
zones = []
for state in hass.states.async_all(ZONE_DOMAIN):
if get_tz(hass, state) != dt_util.DEFAULT_TIME_ZONE:
if not_ha_tz(get_tz(hass, state)):
zones.append(state.entity_id)
etzd.zones = zones

Expand All @@ -102,34 +126,6 @@ def zones_filter(event: Event) -> bool:
hass.bus.async_listen(EVENT_STATE_CHANGED, update_zones, zones_filter)


def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None:
"""Get time zone from entity state."""
if not state:
return STATE_UNAVAILABLE
if state.domain in (PERSON_DOMAIN, DT_DOMAIN) and state.state == STATE_HOME:
return dt_util.DEFAULT_TIME_ZONE
lat = state.attributes.get(ATTR_LATITUDE)
lng = state.attributes.get(ATTR_LONGITUDE)
if lat is None or lng is None:
return STATE_UNAVAILABLE
tz_name = etz_data(hass).tzf.timezone_at(lat=lat, lng=lng)
if tz_name is None:
return None
return dt_util.get_time_zone(tz_name)


async def get_loc(hass: HomeAssistant, state: State | None) -> Location | str | None:
"""Get location data from entity state."""
if state is None:
return STATE_UNAVAILABLE
lat = state.attributes.get(ATTR_LATITUDE)
lng = state.attributes.get(ATTR_LONGITUDE)
if lat is None or lng is None:
return STATE_UNAVAILABLE

return await get_location(hass, lat, lng)


def signal(entry: ConfigEntry) -> str:
"""Return signal name derived from config entry."""
return f"{SIG_ENTITY_CHANGED}-{entry.entry_id}"
Expand Down
2 changes: 1 addition & 1 deletion custom_components/entity_tz/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
"iot_class": "calculated",
"issue_tracker": "https://github.com/pnbruckner/ha-entity-tz/issues",
"loggers": ["geopy"],
"requirements": ["timezonefinder==5.2.0"],
"requirements": ["async-lru==2.0.4", "timezonefinder==5.2.0"],
"version": "1.0.0b9"
}
19 changes: 9 additions & 10 deletions custom_components/entity_tz/nominatim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging

from geopy.adapters import AioHTTPAdapter
from geopy.exc import GeocoderRateLimited
from geopy.exc import GeocoderRateLimited, GeopyError
from geopy.geocoders import Nominatim
from geopy.location import Location

Expand All @@ -18,7 +18,6 @@
NOMINATIM_DATA = "sharable_nominatim"
TIMEOUT = 10
INITIAL_WAIT = 1.5
CACHE_SIZE = 25

_LOGGER = logging.getLogger(__name__)

Expand All @@ -27,7 +26,6 @@
class NomData:
"""Nomination data."""

cache: dict[str, Location | None]
lock: asyncio.Lock
nominatim: Nominatim
wait: float
Expand Down Expand Up @@ -63,21 +61,22 @@ def init_nominatim(hass: HomeAssistant) -> bool:
)
nominatim.adapter.__dict__["session"] = async_get_clientsession(hass)

_save_nom_data(hass, NomData({}, asyncio.Lock(), nominatim, INITIAL_WAIT))
_LOGGER.debug("Initialized Nominatim data with cache size = %i", CACHE_SIZE)
_save_nom_data(hass, NomData(asyncio.Lock(), nominatim, INITIAL_WAIT))
return True


async def get_location(hass: HomeAssistant, lat: float, lng: float) -> Location | None:
"""Get location data from given coordinates."""
async def get_location(hass: HomeAssistant, coordinates: str) -> Location | None:
"""Get location data from given coordinates.
coordinates: string, formatted as 'lat, lng'
"""
nom_data = NomData(**hass.data[NOMINATIM_DATA])

async def limit_rate() -> None:
"""Hold the lock to limit calls to server."""
await asyncio.sleep(nom_data.wait)
nom_data.lock.release()

coordinates = f"{lat}, {lng}"
await nom_data.lock.acquire()
try:
return await nom_data.nominatim.reverse(coordinates)
Expand All @@ -88,8 +87,8 @@ async def limit_rate() -> None:
nom_data.wait = retry_after
_save_nom_data(hass, nom_data)
_LOGGER.warning("Request has been rate limited. Will retry")
return await get_location(hass, lat, lng)
except Exception as exc: # pylint: disable=broad-exception-caught
return await get_location(hass, coordinates)
except GeopyError as exc:
_LOGGER.error("While retrieving reverse geolocation data: %s", exc)
return None
finally:
Expand Down

0 comments on commit b5116c9

Please sign in to comment.