Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,80 @@
import enum

import pytest
from zigpy.quirks.v2.homeassistant import UnitOfPower as QuirksUnitOfPower
from zigpy.quirks.v2.homeassistant import PERCENTAGE, UnitOfPower as QuirksUnitOfPower

from zha.units import UnitOfPower, validate_unit
from zha.units import InvalidUnitOfMeasureException, UnitOfPower, validate_unit


def test_unit_validation() -> None:
"""Test unit validation."""
class NonQuirkUnitEnum(enum.Enum):
"""Non quirk unit enum."""

assert validate_unit(QuirksUnitOfPower.WATT) == UnitOfPower.WATT
ValidUnitPercentage = "%"
InvalidUnitString = "fakeValue"
InvalidUnitInt = 24
InvalidUnitNone = None

class FooUnit(enum.Enum):
"""Foo unit."""

BAR = "bar"
class UnitOfMass(enum.Enum):
"""Unit of mass."""

class UnitOfMass(enum.Enum):
"""UnitOfMass."""
ValidUnitPercentage = "%"
InvalidUnitString = "fakeValue"

BAR = "bar"

with pytest.raises(KeyError):
validate_unit(FooUnit.BAR)
@pytest.mark.parametrize(
"inputUnit,expectedUnitResponse",
[
(QuirksUnitOfPower.WATT, UnitOfPower.WATT),
(NonQuirkUnitEnum.ValidUnitPercentage, PERCENTAGE),
(UnitOfMass.ValidUnitPercentage, PERCENTAGE),
],
)
def test_valid_enum_unit_return_unit_as_string(inputUnit, expectedUnitResponse) -> None:
"""Test validate_unit with valid unit returning unit."""

with pytest.raises(ValueError):
validate_unit(UnitOfMass.BAR)
validatedUnit = validate_unit(inputUnit)

assert validatedUnit == expectedUnitResponse
assert isinstance(validatedUnit, str)


@pytest.mark.parametrize(
"inputUnit,expectedUnitResponse",
[
("W", UnitOfPower.WATT),
("%", PERCENTAGE),
],
)
def test_valid_string_unit_return_unit_as_string(
inputUnit, expectedUnitResponse
) -> None:
"""Test validate_unit with valid unit returning unit."""

validatedUnit = validate_unit(inputUnit)

assert validatedUnit == expectedUnitResponse
assert isinstance(validatedUnit, str)


@pytest.mark.parametrize(
"inputUnit",
[
NonQuirkUnitEnum.InvalidUnitString,
NonQuirkUnitEnum.InvalidUnitInt,
NonQuirkUnitEnum.InvalidUnitNone,
UnitOfMass.InvalidUnitString,
"fakeunit",
42,
None,
"% ", # Contains a valid unit, but has space
"kwh", # Is a valid unit, but invalid for casing.
"WATT", # Matches UnitOfPower.WATT Enum, not correct method to provide unit.
"UnitOfPower.WATT", # Matches UnitOfPower.WATT Enum, not correct method to provide unit.
],
)
def test_invalid_unit_exception_raised(inputUnit) -> None:
"""Test validate_unit with invalid unit raising InvalidUnitOfMeasureException."""

with pytest.raises(InvalidUnitOfMeasureException):
assert validate_unit(inputUnit)
4 changes: 1 addition & 3 deletions zha/application/platforms/number/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,7 @@ def _init_from_quirks_metadata(self, entity_metadata: NumberMetadata) -> None:
_LOGGER,
)
if entity_metadata.unit is not None:
self._attr_native_unit_of_measurement = validate_unit(
entity_metadata.unit
).value
self._attr_native_unit_of_measurement = validate_unit(entity_metadata.unit)

@functools.cached_property
def info_object(self) -> NumberConfigurationEntityInfo:
Expand Down
4 changes: 1 addition & 3 deletions zha/application/platforms/sensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ def _init_from_quirks_metadata(self, entity_metadata: ZCLSensorMetadata) -> None
_LOGGER,
)
if entity_metadata.unit is not None:
self._attr_native_unit_of_measurement = validate_unit(
entity_metadata.unit
).value
self._attr_native_unit_of_measurement = validate_unit(entity_metadata.unit)

@functools.cached_property
def info_object(self) -> SensorEntityInfo:
Expand Down
65 changes: 45 additions & 20 deletions zha/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from typing import Final


class InvalidUnitOfMeasureException(Exception):
"""Exception for invalid unit of measure."""

pass


class UnitOfTemperature(StrEnum):
"""Temperature units."""

Expand Down Expand Up @@ -158,24 +164,43 @@ class UnitOfEnergy(StrEnum):
# Percentage units
PERCENTAGE: Final[str] = "%"


UNITS_OF_MEASURE = {
UnitOfApparentPower.__name__: UnitOfApparentPower,
UnitOfPower.__name__: UnitOfPower,
UnitOfEnergy.__name__: UnitOfEnergy,
UnitOfElectricCurrent.__name__: UnitOfElectricCurrent,
UnitOfElectricPotential.__name__: UnitOfElectricPotential,
UnitOfTemperature.__name__: UnitOfTemperature,
UnitOfTime.__name__: UnitOfTime,
UnitOfFrequency.__name__: UnitOfFrequency,
UnitOfPressure.__name__: UnitOfPressure,
UnitOfVolume.__name__: UnitOfVolume,
UnitOfVolumeFlowRate.__name__: UnitOfVolumeFlowRate,
UnitOfLength.__name__: UnitOfLength,
UnitOfMass.__name__: UnitOfMass,
}


def validate_unit(external_unit: Enum) -> Enum:
UNITS_OF_MEASURE_SET = frozenset(
set(UnitOfApparentPower._value2member_map_.keys())
| set(UnitOfPower._value2member_map_.keys())
| set(UnitOfEnergy._value2member_map_.keys())
| set(UnitOfElectricCurrent._value2member_map_.keys())
| set(UnitOfElectricPotential._value2member_map_.keys())
| set(UnitOfTemperature._value2member_map_.keys())
| set(UnitOfTime._value2member_map_.keys())
| set(UnitOfFrequency._value2member_map_.keys())
| set(UnitOfPressure._value2member_map_.keys())
| set(UnitOfVolume._value2member_map_.keys())
| set(UnitOfVolumeFlowRate._value2member_map_.keys())
| set(UnitOfLength._value2member_map_.keys())
| set(UnitOfMass._value2member_map_.keys())
| {
CONCENTRATION_MICROGRAMS_PER_CUBIC_METER,
CONCENTRATION_MILLIGRAMS_PER_CUBIC_METER,
CONCENTRATION_MICROGRAMS_PER_CUBIC_FOOT,
CONCENTRATION_PARTS_PER_CUBIC_METER,
CONCENTRATION_PARTS_PER_MILLION,
CONCENTRATION_PARTS_PER_BILLION,
SIGNAL_STRENGTH_DECIBELS_MILLIWATT,
SIGNAL_STRENGTH_DECIBELS,
LIGHT_LUX,
PERCENTAGE,
}
)


def validate_unit(unit: str | Enum) -> str:
"""Validate and return a unit of measure."""
return UNITS_OF_MEASURE[type(external_unit).__name__](external_unit.value)

check_unit = unit.value if isinstance(unit, Enum) else unit

if check_unit in UNITS_OF_MEASURE_SET:
return check_unit

raise InvalidUnitOfMeasureException(
f"Invalid unit of measurement: '{check_unit}'. Valid units are: {', '.join(f"'{unit_of_measure}'" for unit_of_measure in UNITS_OF_MEASURE_SET)}."
)