diff --git a/nicegui/elements/mixins/validation_element.py b/nicegui/elements/mixins/validation_element.py index 1f0e5ca1d..54193e6f8 100644 --- a/nicegui/elements/mixins/validation_element.py +++ b/nicegui/elements/mixins/validation_element.py @@ -1,13 +1,18 @@ -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Union from typing_extensions import Self +from ... import background_tasks +from ...helpers import is_coroutine_function from .value_element import ValueElement class ValidationElement(ValueElement): - def __init__(self, validation: Optional[Union[Callable[..., Optional[str]], Dict[str, Callable[..., bool]]]], **kwargs: Any) -> None: + def __init__(self, validation: Optional[Union[ + Callable[..., Union[Optional[str], Awaitable[Optional[str]]]], + Dict[str, Callable[..., bool]], + ]], **kwargs: Any) -> None: self._validation = validation self._auto_validation = True self._error: Optional[str] = None @@ -15,18 +20,24 @@ def __init__(self, validation: Optional[Union[Callable[..., Optional[str]], Dict self._props['error'] = None if validation is None else False # NOTE: reserve bottom space for error message @property - def validation(self) -> Optional[Union[Callable[..., Optional[str]], Dict[str, Callable[..., bool]]]]: + def validation(self) -> Optional[Union[ + Callable[..., Union[Optional[str], Awaitable[Optional[str]]]], + Dict[str, Callable[..., bool]], + ]]: """The validation function or dictionary of validation functions.""" return self._validation @validation.setter - def validation(self, validation: Optional[Union[Callable[..., Optional[str]], Dict[str, Callable[..., bool]]]]) -> None: + def validation(self, validation: Optional[Union[ + Callable[..., Union[Optional[str], Awaitable[Optional[str]]]], + Dict[str, Callable[..., bool]], + ]]) -> None: """Sets the validation function or dictionary of validation functions. :param validation: validation function or dictionary of validation functions (``None`` to disable validation) """ self._validation = validation - self.validate() + self.validate(return_result=False) @property def error(self) -> Optional[str]: @@ -47,13 +58,26 @@ def error(self, error: Optional[str]) -> None: self._props['error-message'] = error self.update() - def validate(self) -> bool: + def validate(self, *, return_result: bool = True) -> bool: """Validate the current value and set the error message if necessary. :return: True if the value is valid, False otherwise """ + if is_coroutine_function(self._validation): + async def await_error(): + assert callable(self._validation) + result = self._validation(self.value) + assert isinstance(result, Awaitable) + self.error = await result + if return_result: + raise NotImplementedError('The validate method cannot return results for async validation functions') + background_tasks.create(await_error()) + return True + if callable(self._validation): - self.error = self._validation(self.value) + result = self._validation(self.value) + assert not isinstance(result, Awaitable) + self.error = result return self.error is None if isinstance(self._validation, dict): @@ -73,4 +97,4 @@ def without_auto_validation(self) -> Self: def _handle_value_change(self, value: Any) -> None: super()._handle_value_change(value) if self._auto_validation: - self.validate() + self.validate(return_result=False)