Skip to content

Commit

Permalink
feat(switch): Improved automatic rule discovery and removal (#31)
Browse files Browse the repository at this point in the history
- Now adds and removes switches properly on update coordinator refresh
- Reduced default update interval from 5 to 1 minute for more responsive updates
- Enhanced switch entity update mechanism with comprehensive data matching
- Implemented robust entity tracking to prevent duplicates
- Added dynamic entity removal for stale or non-existent items
- Simplified coordinator update listener registration
- Bumped version to 1.3.2 to reflect improvements
  • Loading branch information
sirkirby authored Feb 12, 2025
1 parent 6612099 commit c48ea56
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 62 deletions.
2 changes: 1 addition & 1 deletion custom_components/unifi_network_rules/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DEFAULT_RETRY_DELAY = 1

CONF_UPDATE_INTERVAL = "update_interval"
DEFAULT_UPDATE_INTERVAL = 5
DEFAULT_UPDATE_INTERVAL = 1 # Changed from 5 to 1 minute for more responsive updates
SESSION_TIMEOUT = 30

# API Endpoints
Expand Down
2 changes: 1 addition & 1 deletion custom_components/unifi_network_rules/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
"issue_tracker": "https://github.com/sirkirby/unifi-network-rules/issues",
"quality_scale": "custom",
"requirements": ["aiohttp"],
"version": "1.3.0"
"version": "1.3.2"
}
170 changes: 110 additions & 60 deletions custom_components/unifi_network_rules/switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,51 @@ def is_on(self) -> bool:
return self._pending_state
return bool(self._item_data.get('enabled', False))

@callback
def _handle_coordinator_update(self) -> None:
"""Handle updated data from the coordinator."""
if self.coordinator.data is None:
self._item_data = None
self.async_write_ha_state()
return

# Find the updated item data that matches this entity
if 'traffic_routes' in self.coordinator.data:
items = self.coordinator.data['traffic_routes']
item = next((i for i in items if i.get('_id') == self._item_id), None)
if item:
self._item_data = item
self.async_write_ha_state()
return

if 'firewall_policies' in self.coordinator.data:
items = self.coordinator.data['firewall_policies']
item = next((i for i in items if i.get('_id') == self._item_id), None)
if item:
self._item_data = item
self.async_write_ha_state()
return

if 'firewall_rules' in self.coordinator.data:
items = self.coordinator.data['firewall_rules'].get('data', [])
item = next((i for i in items if i.get('_id') == self._item_id), None)
if item:
self._item_data = item
self.async_write_ha_state()
return

if 'traffic_rules' in self.coordinator.data:
items = self.coordinator.data['traffic_rules']
item = next((i for i in items if i.get('_id') == self._item_id), None)
if item:
self._item_data = item
self.async_write_ha_state()
return

# If we get here, the item no longer exists in the coordinator data
self._item_data = None
self.async_write_ha_state()

async def _verify_state_change(self, target_state: bool, get_method, max_attempts: int = 3) -> bool:
"""Verify that the state change was successful."""
for attempt in range(max_attempts):
Expand Down Expand Up @@ -232,76 +277,81 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry, async_add_e
# Get entity registry
entity_registry = async_get(hass)

# Track existing entities to prevent duplicates
existing_entities = {}

@callback
def async_update_items():
"""Update entities."""
def async_update_items(now=None):
"""Update entities when coordinator data changes."""
current_entities = set()
new_entities = []
existing_ids = set()

# Track entities that should exist
valid_entity_ids = set()

# Handle traffic routes (available in both modes)
if coordinator.data and 'traffic_routes' in coordinator.data:
routes = coordinator.data['traffic_routes']
for route in routes:
entity_id = f"network_route_{route['_id']}"
valid_entity_ids.add(f"{DOMAIN}.{entity_id}")
if entity_id not in existing_ids:
new_entities.append(UDMTrafficRouteSwitch(coordinator, api, route))
existing_ids.add(entity_id)

if api.capabilities.zone_based_firewall:
# Handle firewall policies for zone-based firewall
if coordinator.data and 'firewall_policies' in coordinator.data:
policies = coordinator.data['firewall_policies']
for policy in policies:
if coordinator.data:
# Process traffic routes
if 'traffic_routes' in coordinator.data:
for route in coordinator.data['traffic_routes']:
entity_id = f"network_route_{route['_id']}"
current_entities.add(entity_id)
if entity_id not in existing_entities:
new_entity = UDMTrafficRouteSwitch(coordinator, api, route)
new_entities.append(new_entity)
existing_entities[entity_id] = new_entity

# Process firewall policies for zone-based firewall
if api.capabilities.zone_based_firewall and 'firewall_policies' in coordinator.data:
for policy in coordinator.data['firewall_policies']:
if not policy.get('predefined', False):
entity_id = f"network_policy_{policy['_id']}"
valid_entity_ids.add(f"{DOMAIN}.{entity_id}")
if entity_id not in existing_ids:
new_entities.append(UDMFirewallPolicySwitch(coordinator, api, policy, zones_data))
existing_ids.add(entity_id)

if api.capabilities.legacy_firewall:
# Handle legacy firewall rules
if coordinator.data and 'firewall_rules' in coordinator.data:
rules = coordinator.data['firewall_rules'].get('data', [])
for rule in rules:
entity_id = f"network_rule_firewall_{rule['_id']}"
valid_entity_ids.add(f"{DOMAIN}.{entity_id}")
if entity_id not in existing_ids:
new_entities.append(UDMLegacyFirewallRuleSwitch(coordinator, api, rule))
existing_ids.add(entity_id)

# Handle legacy traffic rules
if coordinator.data and 'traffic_rules' in coordinator.data:
rules = coordinator.data['traffic_rules']
for rule in rules:
entity_id = f"network_rule_traffic_{rule['_id']}"
valid_entity_ids.add(f"{DOMAIN}.{entity_id}")
if entity_id not in existing_ids:
new_entities.append(UDMLegacyTrafficRuleSwitch(coordinator, api, rule))
existing_ids.add(entity_id)

# Clean up old entities from the registry
_LOGGER.debug("Valid entity IDs: %s", valid_entity_ids)

all_entities = async_entries_for_config_entry(entity_registry, entry.entry_id)
for entity in all_entities:
if entity.entity_id not in valid_entity_ids:
_LOGGER.info("Removing old entity: %s", entity.entity_id)
entity_registry.async_remove(entity.entity_id)

current_entities.add(entity_id)
if entity_id not in existing_entities:
new_entity = UDMFirewallPolicySwitch(coordinator, api, policy, zones_data)
new_entities.append(new_entity)
existing_entities[entity_id] = new_entity

# Process legacy firewall rules
if api.capabilities.legacy_firewall:
if 'firewall_rules' in coordinator.data:
for rule in coordinator.data['firewall_rules'].get('data', []):
entity_id = f"network_rule_firewall_{rule['_id']}"
current_entities.add(entity_id)
if entity_id not in existing_entities:
new_entity = UDMLegacyFirewallRuleSwitch(coordinator, api, rule)
new_entities.append(new_entity)
existing_entities[entity_id] = new_entity

if 'traffic_rules' in coordinator.data:
for rule in coordinator.data['traffic_rules']:
entity_id = f"network_rule_traffic_{rule['_id']}"
current_entities.add(entity_id)
if entity_id not in existing_entities:
new_entity = UDMLegacyTrafficRuleSwitch(coordinator, api, rule)
new_entities.append(new_entity)
existing_entities[entity_id] = new_entity

# Remove entities that no longer exist
for entity_id in list(existing_entities.keys()):
if entity_id not in current_entities:
entity = existing_entities.pop(entity_id)
hass.async_create_task(async_remove_entity(hass, entity_registry, entity))

# Add new entities
if new_entities:
async_add_entities(new_entities)

# Initial entity setup
async def async_remove_entity(hass: HomeAssistant, registry: EntityRegistry, entity: SwitchEntity):
"""Remove entity from Home Assistant."""
if entity.entity_id:
registry.async_remove(entity.entity_id)
_LOGGER.info("Removed entity %s", entity.entity_id)

# Set up initial entities
async_update_items()

# Register callback for future updates
entry.async_on_unload(coordinator.async_add_listener(async_update_items))

# Register update listener
coordinator.async_add_listener(async_update_items)
entry.async_on_unload(lambda: coordinator.async_remove_listener(async_update_items))

return True

def async_entries_for_config_entry(registry: EntityRegistry, config_entry_id: str) -> List[Any]:
"""Get all entities for a config entry."""
Expand Down

0 comments on commit c48ea56

Please sign in to comment.