Skip to content

Commit

Permalink
refactor(udm_api): Improve UDM capability detection and legacy endpoi…
Browse files Browse the repository at this point in the history
…nt handling (#25)

- Enhanced capability detection with feature migration status check
- Added more robust method to detect zone-based firewall support
- Implemented fallback mechanism for legacy endpoint detection
- Improved logging and error handling for capability checks
- Added authentication checks for legacy firewall and traffic rule endpoints
  • Loading branch information
sirkirby authored Feb 8, 2025
1 parent 88133b0 commit fd55c47
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions custom_components/unifi_network_rules/udm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,30 @@ def __init__(self, host: str, username: str, password: str, max_retries: int = 3
async def detect_capabilities(self) -> bool:
"""Detect UDM capabilities by checking endpoints."""
try:
# Check zone-based firewall
zone_success, zone_data, _ = await self.get_firewall_zone_matrix()
self.capabilities.zone_based_firewall = zone_success and zone_data is not None

# Check legacy firewall
rules_success, rules_data, _ = await self.get_legacy_firewall_rules()
routes_success, routes_data, _ = await self.get_legacy_traffic_rules()

self.capabilities.legacy_firewall = (
(rules_success and isinstance(rules_data, dict)) or
(routes_success and isinstance(routes_data, list))
)
# Check feature migration status first
url = f"https://{self.host}/proxy/network/v2/api/site/default/site-feature-migration"
success, migrations, error = await self._make_authenticated_request('get', url)

_LOGGER.debug("Feature migration check: success=%s, data=%s, error=%s",
success, migrations, error)

if success and isinstance(migrations, list):
self.capabilities.zone_based_firewall = any(
m.get("feature") == "ZONE_BASED_FIREWALL"
for m in migrations
)
else:
# If migration check fails, check policies endpoint
success, policies, error = await self.get_firewall_policies()
_LOGGER.debug("Firewall policies check: success=%s, has_data=%s, error=%s",
success, bool(policies), error)
self.capabilities.zone_based_firewall = success

# If not zone-based, check legacy endpoints
if not self.capabilities.zone_based_firewall:
self.capabilities.legacy_firewall = await self._check_legacy_endpoints()
else:
self.capabilities.legacy_firewall = False

# Check traffic routes (available in both modes)
routes_success, routes_data, _ = await self.get_traffic_routes()
Expand All @@ -76,6 +88,23 @@ async def detect_capabilities(self) -> bool:
_LOGGER.error("Error detecting UDM capabilities: %s", str(e))
return False

async def _check_legacy_endpoints(self) -> bool:
"""Check if legacy endpoints return data."""
# Try legacy firewall rules endpoint
success, rules, error = await self.get_legacy_firewall_rules()
_LOGGER.debug("Legacy firewall check: success=%s, has_data=%s, error=%s",
success, bool(rules), error)

if success and rules:
return True

# Try legacy traffic rules endpoint as backup
success, rules, error = await self.get_legacy_traffic_rules()
_LOGGER.debug("Legacy traffic check: success=%s, has_data=%s, error=%s",
success, bool(rules), error)

return success and bool(rules)

async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create an aiohttp session."""
if self._session is None or self._session.closed:
Expand Down Expand Up @@ -349,6 +378,10 @@ async def get_firewall_zone_matrix(self) -> Tuple[bool, Optional[List[Dict[str,

async def get_legacy_firewall_rules(self) -> Tuple[bool, Optional[List[Dict[str, Any]]], Optional[str]]:
"""Fetch legacy firewall rules from the UDM."""
auth_success, auth_error = await self.ensure_authenticated()
if not auth_success:
return False, None, f"Authentication failed: {auth_error}"

url = f"https://{self.host}/proxy/network/api/s/default/rest/firewallrule"
success, response, error = await self._make_authenticated_request('get', url)

Expand All @@ -362,6 +395,10 @@ async def get_legacy_firewall_rules(self) -> Tuple[bool, Optional[List[Dict[str,

async def get_legacy_traffic_rules(self) -> Tuple[bool, Optional[List[Dict[str, Any]]], Optional[str]]:
"""Fetch legacy traffic rules from the UDM."""
auth_success, auth_error = await self.ensure_authenticated()
if not auth_success:
return False, None, f"Authentication failed: {auth_error}"

url = f"https://{self.host}/proxy/network/v2/api/site/default/trafficrules"
success, rules, error = await self._make_authenticated_request('get', url)

Expand Down

0 comments on commit fd55c47

Please sign in to comment.