Skip to content

Commit

Permalink
Move Nominatim code into separate, sharable module
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbruckner committed Dec 19, 2023
1 parent cdf5397 commit d266c44
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 73 deletions.
18 changes: 10 additions & 8 deletions custom_components/entity_tz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from homeassistant.util import dt as dt_util

from .const import DOMAIN
from .helpers import etz_data, get_location, get_tz, init_etz_data, signal
from .helpers import etz_data, get_loc, get_tz, init_etz_data, signal

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = [Platform.BINARY_SENSOR, Platform.SENSOR]
Expand All @@ -34,8 +34,9 @@ async def async_setup(hass: HomeAssistant, _: ConfigType) -> bool:

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up config entry."""
etz_data(hass).loc_users[entry.entry_id] = 0
etz_data(hass).tz_users[entry.entry_id] = 0
etzd = etz_data(hass)
etzd.loc_users[entry.entry_id] = 0
etzd.tz_users[entry.entry_id] = 0

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

Expand Down Expand Up @@ -63,11 +64,11 @@ async def update_from_entity(event: Event | None = None) -> None:
or new_state.attributes.get(ATTR_LONGITUDE)
!= old_state.attributes.get(ATTR_LONGITUDE)
):
if etz_data(hass).loc_users[entry.entry_id]:
entity_loc = await get_location(hass, new_state)
if etzd.loc_available and etzd.loc_users[entry.entry_id]:
entity_loc = await get_loc(hass, new_state)
else:
entity_loc = None
if etz_data(hass).tz_users[entry.entry_id]:
if etzd.tz_users[entry.entry_id]:
entity_tz = get_tz(hass, new_state)
else:
entity_tz = None
Expand All @@ -93,8 +94,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
res = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

loc_users = etz_data(hass).loc_users.pop(entry.entry_id)
tz_uers = etz_data(hass).tz_users.pop(entry.entry_id)
etzd = etz_data(hass)
loc_users = etzd.loc_users.pop(entry.entry_id)
tz_uers = etzd.tz_users.pop(entry.entry_id)
assert not loc_users
assert not tz_uers
return res
3 changes: 0 additions & 3 deletions custom_components/entity_tz/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
LOCAL_TIME_ICON = "mdi:account-clock"
TIME_ZONE_ICON = "mdi:map-clock"

NOM_TIMEOUT = 10
NOM_WAIT = 1.5

ATTR_COUNTRY_CODE = "country_code"
ATTR_UTC_OFFSET = "utc_offset"

Expand Down
88 changes: 26 additions & 62 deletions custom_components/entity_tz/helpers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
"""Entity Time Zone Sensor Helpers."""
from __future__ import annotations

import asyncio
from collections.abc import Container, Mapping
from dataclasses import dataclass
from datetime import tzinfo
from enum import Enum, auto
import logging
from typing import Any, cast
from zoneinfo import available_timezones

from geopy.adapters import AioHTTPAdapter
from geopy.geocoders import Nominatim
from geopy.location import Location
from timezonefinder import TimezoneFinder

Expand All @@ -31,10 +27,6 @@
STATE_UNAVAILABLE,
)
from homeassistant.core import Event, HomeAssistant, State, callback, split_entity_id
from homeassistant.helpers.aiohttp_client import (
SERVER_SOFTWARE,
async_get_clientsession,
)
from homeassistant.helpers.device_registry import DeviceEntryType

# Device Info moved to device_registry in 2023.9
Expand All @@ -48,20 +40,20 @@
from homeassistant.helpers.event import async_track_time_change
from homeassistant.util import dt as dt_util

from .const import DOMAIN, NOM_TIMEOUT, NOM_WAIT, SIG_ENTITY_CHANGED
from .const import DOMAIN, SIG_ENTITY_CHANGED
from .nominatim import get_location, init_nominatim

_LOGGER = logging.getLogger(__name__)
_ALWAYS_DISABLED_ENTITIES = ("address", "country", "diff_country", "diff_time")


@dataclass(init=False)
class ETZData:
"""Entity Time Zone integration data."""

tzs: set[str]
nominatim: Nominatim
loc_available: bool
loc_users: dict[str, int]
query_lock: asyncio.Lock
tzf: TimezoneFinder
tzs: set[str]
tz_users: dict[str, int]
zones: list[str]

Expand All @@ -75,29 +67,19 @@ async def init_etz_data(hass: HomeAssistant) -> None:
"""Initialize integration's data."""
if DOMAIN in hass.data:
return
hass.data[DOMAIN] = ETZData()
etz_data(hass).loc_users = {}
etz_data(hass).tz_users = {}

nominatim = Nominatim(
user_agent=SERVER_SOFTWARE,
timeout=NOM_TIMEOUT,
adapter_factory=lambda proxies, ssl_context: AioHTTPAdapter(
proxies=proxies, ssl_context=ssl_context
),
)
nominatim.adapter.__dict__["session"] = async_get_clientsession(hass)
etz_data(hass).nominatim = nominatim
etz_data(hass).query_lock = asyncio.Lock()
hass.data[DOMAIN] = etzd = ETZData()
etzd.loc_available = init_nominatim(hass)
etzd.loc_users = {}
etzd.tz_users = {}

def init_tz_data() -> None:
"""Initialize time zone data."""

# This must be done in an executor since zoneinfo.available_timezones and the
# TimezoneFinder constructor both do file I/O.
"""Initialize time zone data.
etz_data(hass).tzs = available_timezones()
etz_data(hass).tzf = TimezoneFinder()
This must be done in an executor since TimezoneFinder constructor and
zoneinfo.available_timezones both do file I/O.
"""
etzd.tzf = TimezoneFinder()
etzd.tzs = available_timezones()

await hass.async_add_executor_job(init_tz_data)

Expand All @@ -108,7 +90,7 @@ def update_zones(_: Event | None = None) -> None:
for state in hass.states.async_all(ZONE_DOMAIN):
if get_tz(hass, state) != dt_util.DEFAULT_TIME_ZONE:
zones.append(state.entity_id)
etz_data(hass).zones = zones
etzd.zones = zones

@callback
def zones_filter(event: Event) -> bool:
Expand All @@ -121,7 +103,7 @@ def zones_filter(event: Event) -> bool:


def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None:
"""Get time zone from latitude & longitude from state."""
"""Get time zone from entity state."""
if not state:
return STATE_UNAVAILABLE
if state.domain in (PERSON_DOMAIN, DT_DOMAIN) and state.state == STATE_HOME:
Expand All @@ -136,43 +118,23 @@ def get_tz(hass: HomeAssistant, state: State | None) -> tzinfo | str | None:
return dt_util.get_time_zone(tz_name)


async def get_location(
hass: HomeAssistant, state: State | None
) -> Location | str | None:
"""Get address from latitude & longitude."""
async def get_loc(hass: HomeAssistant, state: State | None) -> Location | str | None:
"""Get location data from entity state."""
if state is None:
return STATE_UNAVAILABLE
lat = state.attributes.get(ATTR_LATITUDE)
lng = state.attributes.get(ATTR_LONGITUDE)
if lat is None or lng is None:
return STATE_UNAVAILABLE

lock = etz_data(hass).query_lock

async def limit_rate() -> None:
"""Hold the lock to limit calls to server."""
await asyncio.sleep(NOM_WAIT)
lock.release()

coordinates = f"{lat}, {lng}"
await lock.acquire()
try:
return await etz_data(hass).nominatim.reverse(coordinates)
except Exception as exc: # pylint: disable=broad-exception-caught
_LOGGER.error("While getting address & country code: %s", exc)
return None
finally:
hass.async_create_background_task(limit_rate(), "Limit nominatim query rate")
return await get_location(hass, lat, lng)


def signal(entry: ConfigEntry) -> str:
"""Return signal name derived from config entry."""
return f"{SIG_ENTITY_CHANGED}-{entry.entry_id}"


_ALWAYS_DISABLED_ENTITIES = ("address", "country", "diff_country", "diff_time")


def _enable_entity(key: str, entry_data: Mapping[str, Any]) -> bool:
"""Determine if entity should be enabled by default."""
if key in _ALWAYS_DISABLED_ENTITIES:
Expand Down Expand Up @@ -252,10 +214,11 @@ def entity_changed(
loc_user = ETZSource.LOC in self._sources
tz_user = ETZSource.TZ in self._sources
if loc_user or tz_user:
etzd = etz_data(self.hass)
if loc_user:
etz_data(self.hass).loc_users[config_entry.entry_id] += 1
etzd.loc_users[config_entry.entry_id] += 1
if tz_user:
etz_data(self.hass).tz_users[config_entry.entry_id] += 1
etzd.tz_users[config_entry.entry_id] += 1
self.async_on_remove(
async_dispatcher_connect(
self.hass, signal(config_entry), entity_changed
Expand All @@ -278,10 +241,11 @@ def update(_: Any) -> None:
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
config_entry = cast(ConfigEntry, self.platform.config_entry)
etzd = etz_data(self.hass)
if ETZSource.LOC in self._sources:
etz_data(self.hass).loc_users[config_entry.entry_id] -= 1
etzd.loc_users[config_entry.entry_id] -= 1
if ETZSource.TZ in self._sources:
etz_data(self.hass).tz_users[config_entry.entry_id] -= 1
etzd.tz_users[config_entry.entry_id] -= 1

@property
def _sources_valid(self) -> bool:
Expand Down
1 change: 1 addition & 0 deletions custom_components/entity_tz/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"documentation": "https://github.com/pnbruckner/ha-entity-tz/blob/master/README.md",
"iot_class": "calculated",
"issue_tracker": "https://github.com/pnbruckner/ha-entity-tz/issues",
"loggers": ["geopy"],
"requirements": ["timezonefinder==5.2.0"],
"version": "1.0.0b9"
}
96 changes: 96 additions & 0 deletions custom_components/entity_tz/nominatim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Nominatim from geopy helper."""

import asyncio
from dataclasses import dataclass, fields
import logging

from geopy.adapters import AioHTTPAdapter
from geopy.exc import GeocoderRateLimited
from geopy.geocoders import Nominatim
from geopy.location import Location

from homeassistant.core import HomeAssistant
from homeassistant.helpers.aiohttp_client import (
SERVER_SOFTWARE,
async_get_clientsession,
)

NOMINATIM_DATA = "sharable_nominatim"
TIMEOUT = 10
INITIAL_WAIT = 1.5
CACHE_SIZE = 25

_LOGGER = logging.getLogger(__name__)


@dataclass
class NomData:
"""Nomination data."""

cache: dict[str, Location | None]
lock: asyncio.Lock
nominatim: Nominatim
wait: float


def _save_nom_data(hass: HomeAssistant, nom_data: NomData) -> None:
"""Save nominatim data to hass.data."""
hass.data[NOMINATIM_DATA] = {
field.name: getattr(nom_data, field.name) for field in fields(nom_data)
}


def init_nominatim(hass: HomeAssistant) -> bool:
"""Initialize sharable Nominatim object."""
if NOMINATIM_DATA in hass.data:
nom_data = hass.data[NOMINATIM_DATA]
try:
for field in fields(NomData):
if not isinstance(nom_data[field.name], field.type):
raise TypeError
except (KeyError, TypeError):
msg = f"Unexpected data in hass.data[{NOMINATIM_DATA!r}]: {nom_data}"
_LOGGER.error(msg)
return False
return True

nominatim = Nominatim(
user_agent=SERVER_SOFTWARE,
timeout=TIMEOUT,
adapter_factory=lambda proxies, ssl_context: AioHTTPAdapter(
proxies=proxies, ssl_context=ssl_context
),
)
nominatim.adapter.__dict__["session"] = async_get_clientsession(hass)

_save_nom_data(hass, NomData({}, asyncio.Lock(), nominatim, INITIAL_WAIT))
_LOGGER.debug("Initialized Nominatim data with cache size = %i", CACHE_SIZE)
return True


async def get_location(hass: HomeAssistant, lat: float, lng: float) -> Location | None:
"""Get location data from given coordinates."""
nom_data = NomData(**hass.data[NOMINATIM_DATA])

async def limit_rate() -> None:
"""Hold the lock to limit calls to server."""
await asyncio.sleep(nom_data.wait)
nom_data.lock.release()

coordinates = f"{lat}, {lng}"
await nom_data.lock.acquire()
try:
return await nom_data.nominatim.reverse(coordinates)
except GeocoderRateLimited as exc:
if retry_after := exc.retry_after:
if retry_after > nom_data.wait:
_LOGGER.debug("Increasing wait time to %f sec", retry_after)
nom_data.wait = retry_after
_save_nom_data(hass, nom_data)
_LOGGER.warning("Request has been rate limited. Will retry")
return await get_location(hass, lat, lng)
except Exception as exc: # pylint: disable=broad-exception-caught
_LOGGER.error("While retrieving reverse geolocation data: %s", exc)
return None
finally:
hass.async_create_background_task(limit_rate(), "Limit nominatim query rate")

0 comments on commit d266c44

Please sign in to comment.