Skip to content

Commit

Permalink
Refactor targets for zwave_js services (#115734)
Browse files Browse the repository at this point in the history
* Let labels be used as targets for zwave_js services

* add coverage

* Fix test bug and switch from targets to fields

* Remove label addition

* Remove labels from service descriptions

* Remove labels from strings

* More changes
  • Loading branch information
raman325 authored Aug 22, 2024
1 parent 281a9f0 commit fc1ed7d
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 86 deletions.
26 changes: 12 additions & 14 deletions homeassistant/components/zwave_js/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,18 @@ def async_get_nodes_from_area_id(
}
)
# Add devices in an area that are Z-Wave JS devices
for device in dr.async_entries_for_area(dev_reg, area_id):
if next(
(
config_entry_id
for config_entry_id in device.config_entries
if cast(
ConfigEntry,
hass.config_entries.async_get_entry(config_entry_id),
).domain
== DOMAIN
),
None,
):
nodes.add(async_get_node_from_device_id(hass, device.id, dev_reg))
nodes.update(
async_get_node_from_device_id(hass, device.id, dev_reg)
for device in dr.async_entries_for_area(dev_reg, area_id)
if any(
cast(
ConfigEntry,
hass.config_entries.async_get_entry(config_entry_id),
).domain
== DOMAIN
for config_entry_id in device.config_entries
)
)

return nodes

Expand Down
64 changes: 13 additions & 51 deletions homeassistant/components/zwave_js/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@

type _NodeOrEndpointType = ZwaveNode | Endpoint

TARGET_VALIDATORS = {
vol.Optional(ATTR_AREA_ID): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_DEVICE_ID): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
}


def parameter_name_does_not_need_bitmask(
val: dict[str, int | str | list[str]],
Expand Down Expand Up @@ -261,13 +267,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Optional(const.ATTR_ENDPOINT, default=0): vol.Coerce(int),
vol.Required(const.ATTR_CONFIG_PARAMETER): vol.Any(
vol.Coerce(int), cv.string
Expand Down Expand Up @@ -305,13 +305,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Optional(const.ATTR_ENDPOINT, default=0): vol.Coerce(int),
vol.Required(const.ATTR_CONFIG_PARAMETER): vol.Coerce(int),
vol.Required(const.ATTR_CONFIG_VALUE): vol.Any(
Expand Down Expand Up @@ -356,13 +350,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Required(const.ATTR_COMMAND_CLASS): vol.Coerce(int),
vol.Required(const.ATTR_PROPERTY): vol.Any(
vol.Coerce(int), str
Expand Down Expand Up @@ -391,13 +379,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Optional(const.ATTR_BROADCAST, default=False): cv.boolean,
vol.Required(const.ATTR_COMMAND_CLASS): vol.Coerce(int),
vol.Required(const.ATTR_PROPERTY): vol.Any(
Expand Down Expand Up @@ -428,15 +410,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
self.async_ping,
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
},
TARGET_VALIDATORS,
cv.has_at_least_one_key(
ATTR_DEVICE_ID, ATTR_ENTITY_ID, ATTR_AREA_ID
),
Expand All @@ -453,13 +427,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Required(const.ATTR_COMMAND_CLASS): vol.All(
vol.Coerce(int), vol.Coerce(CommandClass)
),
Expand All @@ -483,13 +451,7 @@ def validate_entities(val: dict[str, Any]) -> dict[str, Any]:
schema=vol.Schema(
vol.All(
{
vol.Optional(ATTR_AREA_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_DEVICE_ID): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(ATTR_ENTITY_ID): cv.entity_ids,
**TARGET_VALIDATORS,
vol.Required(const.ATTR_NOTIFICATION_TYPE): vol.All(
vol.Coerce(int), vol.Coerce(NotificationType)
),
Expand Down
Loading

0 comments on commit fc1ed7d

Please sign in to comment.