Skip to content

Commit

Permalink
feat(api): Enhanced authentication and error handling (#17)
Browse files Browse the repository at this point in the history
- Improved authentication flow with more robust token and session management
- Added quick authentication check method
- Enhanced error handling and logging for API requests
- Implemented better rate limiting and request interval management
- Added more debug logging for authentication and request processes
- Refactored API methods to use a more consistent request pattern
  • Loading branch information
sirkirby authored Feb 5, 2025
1 parent f4e7811 commit c0790ad
Show file tree
Hide file tree
Showing 3 changed files with 471 additions and 302 deletions.
171 changes: 113 additions & 58 deletions custom_components/unifi_network_rules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Support for UniFi Network Rules."""
import asyncio
import logging
from datetime import timedelta

from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD
Expand All @@ -9,7 +11,15 @@
from homeassistant.helpers import config_validation as cv
from homeassistant.exceptions import ConfigEntryNotReady

from .const import DOMAIN, CONF_MAX_RETRIES, CONF_RETRY_DELAY, DEFAULT_MAX_RETRIES, DEFAULT_RETRY_DELAY, CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL
from .const import (
DOMAIN,
CONF_MAX_RETRIES,
CONF_RETRY_DELAY,
DEFAULT_MAX_RETRIES,
DEFAULT_RETRY_DELAY,
CONF_UPDATE_INTERVAL,
DEFAULT_UPDATE_INTERVAL
)
from .udm_api import UDMAPI

_LOGGER = logging.getLogger(__name__)
Expand All @@ -19,71 +29,116 @@
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the UDM Rule Manager component."""
hass.data.setdefault(DOMAIN, {})
return True
"""Set up the UniFi Network Rules component."""
hass.data.setdefault(DOMAIN, {})
return True

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up UDM Rule Manager from a config entry."""
host = entry.data[CONF_HOST]
username = entry.data[CONF_USERNAME]
password = entry.data[CONF_PASSWORD]
update_interval = entry.data.get(CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL)
max_retries = entry.data.get(CONF_MAX_RETRIES, DEFAULT_MAX_RETRIES)
retry_delay = entry.data.get(CONF_RETRY_DELAY, DEFAULT_RETRY_DELAY)

api = UDMAPI(host, username, password, max_retries=max_retries, retry_delay=retry_delay)

# Test the connection
success, error_message = await api.login()
if not success:
raise ConfigEntryNotReady(f"Failed to connect to UDM: {error_message}")

async def async_update_data():
"""Fetch data from API."""
"""Set up UniFi Network Rules from a config entry."""
_LOGGER.debug("Setting up UniFi Network Rules config entry")

host = entry.data[CONF_HOST]
username = entry.data[CONF_USERNAME]
password = entry.data[CONF_PASSWORD]
update_interval = entry.data.get(CONF_UPDATE_INTERVAL, DEFAULT_UPDATE_INTERVAL)
max_retries = entry.data.get(CONF_MAX_RETRIES, DEFAULT_MAX_RETRIES)
retry_delay = entry.data.get(CONF_RETRY_DELAY, DEFAULT_RETRY_DELAY)

api = UDMAPI(host, username, password, max_retries=max_retries, retry_delay=retry_delay)

# Test the connection with quick auth check
try:
policies_success, policies, policies_error = await api.get_firewall_policies()
routes_success, traffic_routes, routes_error = await api.get_traffic_routes()

if not policies_success:
raise Exception(f"Failed to fetch firewall policies: {policies_error}")
if not routes_success:
raise Exception(f"Failed to fetch traffic routes: {routes_error}")

return {
"firewall_policies": policies,
"traffic_routes": traffic_routes
}
success, error = await api.quick_auth_check()
if not success:
await api.cleanup()
_LOGGER.error(f"Failed to connect to UDM: {error}")
raise ConfigEntryNotReady(f"Failed to connect to UDM: {error}")
except Exception as e:
_LOGGER.error(f"Error updating data: {str(e)}")
raise

coordinator = DataUpdateCoordinator(
hass,
_LOGGER,
name="udm_rule_manager",
update_method=async_update_data,
update_interval=timedelta(minutes=update_interval),
)

# Fetch initial data
await coordinator.async_config_entry_first_refresh()

hass.data[DOMAIN][entry.entry_id] = {
'api': api,
'coordinator': coordinator
}

await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

return True
await api.cleanup()
_LOGGER.exception("Error during setup")
raise ConfigEntryNotReady(f"Setup failed: {str(e)}") from e

async def async_update_data():
"""Fetch data from API."""
try:
# Add delay between requests
policies_success, policies, policies_error = await api.get_firewall_policies()
if not policies_success:
raise Exception(f"Failed to fetch firewall policies: {policies_error}")

await asyncio.sleep(2) # Wait between requests

routes_success, traffic_routes, routes_error = await api.get_traffic_routes()
if not routes_success:
raise Exception(f"Failed to fetch traffic routes: {routes_error}")

return {
"firewall_policies": policies,
"traffic_routes": traffic_routes
}
except Exception as e:
_LOGGER.error(f"Error updating data: {str(e)}")
raise

coordinator = DataUpdateCoordinator(
hass,
_LOGGER,
name="udm_rule_manager",
update_method=async_update_data,
update_interval=timedelta(minutes=max(update_interval, 15)),
)

# Store api and coordinator
hass.data[DOMAIN][entry.entry_id] = {
'api': api,
'coordinator': coordinator,
}

# Register cleanup for config entry
entry.async_on_unload(
lambda: hass.async_create_task(cleanup_api(hass, entry))
)

# Set up platforms
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

# Initial data fetch
try:
_LOGGER.debug("Performing initial data refresh")
await coordinator.async_config_entry_first_refresh()
except Exception as err:
_LOGGER.error(f"Initial data refresh failed: {err}")
await cleanup_api(hass, entry)
raise ConfigEntryNotReady from err

return True

async def cleanup_api(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Clean up API resources."""
if DOMAIN in hass.data and entry.entry_id in hass.data[DOMAIN]:
api = hass.data[DOMAIN][entry.entry_id].get('api')
if api is not None:
try:
await api.cleanup()
except Exception as e:
_LOGGER.error(f"Error during API cleanup: {e}")

async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
_LOGGER.debug("Unloading UniFi Network Rules config entry")

# Unload platforms
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

if unload_ok:
api = hass.data[DOMAIN][entry.entry_id]['api']
await api.cleanup() # Cleanup API session
# Clean up API
await cleanup_api(hass, entry)
# Remove entry data
hass.data[DOMAIN].pop(entry.entry_id)

return unload_ok
return unload_ok

async def async_reload_entry(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Reload config entry."""
await async_unload_entry(hass, entry)
await async_setup_entry(hass, entry)
144 changes: 86 additions & 58 deletions custom_components/unifi_network_rules/switch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Support for UniFi Network Rules switches."""
import logging
from homeassistant.components.switch import SwitchEntity
from homeassistant.core import callback
Expand All @@ -6,24 +7,37 @@
from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from homeassistant.helpers.event import async_call_later

from .const import DOMAIN

_LOGGER = logging.getLogger(__name__)

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback):
"""Set up the UDM Rule Manager switches."""
"""Set up the UniFi Network Rules switches."""
coordinator = hass.data[DOMAIN][entry.entry_id]['coordinator']
api = hass.data[DOMAIN][entry.entry_id]['api']

_LOGGER.debug("Setting up switches with coordinator data: %s", coordinator.data)

# Check if we have data
if not coordinator.data:
_LOGGER.debug("No data available yet, waiting for coordinator refresh")
await coordinator.async_refresh()

entities = []

for policy in coordinator.data.get('firewall_policies', []):
# Add firewall policy switches
policies = coordinator.data.get('firewall_policies', []) if coordinator.data else []
for policy in policies:
entities.append(UDMFirewallPolicySwitch(coordinator, api, policy))

for route in coordinator.data.get('traffic_routes', []):
# Add traffic route switches
routes = coordinator.data.get('traffic_routes', []) if coordinator.data else []
for route in routes:
entities.append(UDMTrafficRouteSwitch(coordinator, api, route))

_LOGGER.debug("Adding %d entities", len(entities))
async_add_entities(entities, True)

class UDMTrafficRouteSwitch(CoordinatorEntity, SwitchEntity):
Expand All @@ -37,6 +51,11 @@ def __init__(self, coordinator, api, route):
self._attr_name = f"Traffic Route: {route.get('description', 'Unnamed')}"
self._route_id = route['_id']

@property
def available(self) -> bool:
"""Return if entity is available."""
return self.coordinator.last_update_success and self._get_route() is not None

@property
def is_on(self):
"""Return true if the switch is on."""
Expand Down Expand Up @@ -64,6 +83,8 @@ async def _toggle(self, new_state):

def _get_route(self):
"""Get the current route from the coordinator data."""
if not self.coordinator.data:
return None
routes = self.coordinator.data.get('traffic_routes', [])
return next((r for r in routes if r['_id'] == self._route_id), None)

Expand Down Expand Up @@ -98,62 +119,69 @@ def extra_state_attributes(self):
return attributes

class UDMFirewallPolicySwitch(CoordinatorEntity, SwitchEntity):
"""Representation of a UDM Firewall Policy Switch."""

def __init__(self, coordinator, api, policy):
"""Initialize the UDM Firewall Policy Switch."""
super().__init__(coordinator)
self._api = api
self._attr_unique_id = f"firewall_policy_{policy['_id']}"
self._attr_name = f"Firewall Policy: {policy.get('name', 'Unnamed')}"
self._policy_id = policy['_id']

@property
def is_on(self):
"""Return true if the switch is on."""
policy = self._get_policy()
return policy['enabled'] if policy else False

async def async_turn_on(self, **kwargs):
"""Turn the switch on."""
await self._toggle(True)

async def async_turn_off(self, **kwargs):
"""Turn the switch off."""
await self._toggle(False)

async def _toggle(self, new_state):
"""Toggle the policy state."""
try:
success, error_message = await self._api.toggle_firewall_policy(self._policy_id, new_state)
if success:
await self.coordinator.async_request_refresh()
else:
raise HomeAssistantError(f"Failed to toggle firewall policy: {error_message}")
except Exception as e:
raise HomeAssistantError(f"Error toggling firewall policy: {str(e)}")

def _get_policy(self):
"""Representation of a UDM Firewall Policy Switch."""

def __init__(self, coordinator, api, policy):
"""Initialize the UDM Firewall Policy Switch."""
super().__init__(coordinator)
self._api = api
self._attr_unique_id = f"firewall_policy_{policy['_id']}"
self._attr_name = f"Firewall Policy: {policy.get('name', 'Unnamed')}"
self._policy_id = policy['_id']

@property
def available(self) -> bool:
"""Return if entity is available."""
return self.coordinator.last_update_success and self._get_policy() is not None

@property
def is_on(self):
"""Return true if the switch is on."""
policy = self._get_policy()
return policy['enabled'] if policy else False

async def async_turn_on(self, **kwargs):
"""Turn the switch on."""
await self._toggle(True)

async def async_turn_off(self, **kwargs):
"""Turn the switch off."""
await self._toggle(False)

async def _toggle(self, new_state):
"""Toggle the policy state."""
try:
success, error_message = await self._api.toggle_firewall_policy(self._policy_id, new_state)
if success:
await self.coordinator.async_request_refresh()
else:
raise HomeAssistantError(f"Failed to toggle firewall policy: {error_message}")
except Exception as e:
raise HomeAssistantError(f"Error toggling firewall policy: {str(e)}")

def _get_policy(self):
"""Get the current policy from the coordinator data."""
if not self.coordinator.data:
return None
policies = self.coordinator.data.get('firewall_policies', [])
return next((p for p in policies if p['_id'] == self._policy_id), None)

@property
def extra_state_attributes(self):
"""Return additional state attributes."""
policy = self._get_policy()
if not policy:
return {}

return {
"name": policy.get("name", ""),
"action": policy.get("action", ""),
"predefined": policy.get("predefined", False),
"protocol": policy.get("protocol", ""),
"schedule_mode": policy.get("schedule", {}).get("mode", ""),
"source_zone": policy.get("source", {}).get("zone_id", ""),
"destination_zone": policy.get("destination", {}).get("zone_id", ""),
"index": policy.get("index", 0),
"matching_target": policy.get("source", {}).get("matching_target", ""),
"ip_version": policy.get("ip_version", "")
}
@property
def extra_state_attributes(self):
"""Return additional state attributes."""
policy = self._get_policy()
if not policy:
return {}

return {
"name": policy.get("name", ""),
"action": policy.get("action", ""),
"predefined": policy.get("predefined", False),
"protocol": policy.get("protocol", ""),
"schedule_mode": policy.get("schedule", {}).get("mode", ""),
"source_zone": policy.get("source", {}).get("zone_id", ""),
"destination_zone": policy.get("destination", {}).get("zone_id", ""),
"index": policy.get("index", 0),
"matching_target": policy.get("source", {}).get("matching_target", ""),
"ip_version": policy.get("ip_version", "")
}
Loading

0 comments on commit c0790ad

Please sign in to comment.