diff --git a/Makefile b/Makefile index f8d968f7..af890d57 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ lint: .PHONY: mypy mypy: - mypy pydantic_extra_types + @mypy pydantic_extra_types .PHONY: test test: diff --git a/pydantic_extra_types/timezone_name.py b/pydantic_extra_types/timezone_name.py new file mode 100644 index 00000000..cc092b69 --- /dev/null +++ b/pydantic_extra_types/timezone_name.py @@ -0,0 +1,139 @@ +"""Time zone name validation and serialization module.""" + +from __future__ import annotations + +import importlib +import sys +import warnings +from typing import Any, List + +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import PydanticCustomError, core_schema + + +def _is_available(name: str) -> bool: + """Check if a module is available for import.""" + try: + importlib.import_module(name) + return True + except ModuleNotFoundError: + return False + + +def _tz_provider_from_zoneinfo() -> set[str]: + """Get timezones from the zoneinfo module.""" + from zoneinfo import available_timezones + + return set(available_timezones()) + + +def _tz_provider_from_pytz() -> set[str]: + """Get timezones from the pytz module.""" + from pytz import all_timezones + + return set(all_timezones) + + +def _warn_about_pytz_usage() -> None: + """Warn about using pytz with Python 3.9 or later.""" + warnings.warn( + 'Projects using Python 3.9 or later should be using the support now included as part of the standard library zone-info. ' + 'Please consider switching to the standard library module.' + ) + + +def get_timezones() -> set[str]: + """Determine the timezone provider and return available timezones.""" + if _is_available('zoneinfo') and _is_available('tzdata'): + return _tz_provider_from_zoneinfo() + elif _is_available('pytz'): + if sys.version_info[:2] > (3, 8): + _warn_about_pytz_usage() + return _tz_provider_from_pytz() + else: + if sys.version_info[:2] == (3, 8): + raise ImportError('No pytz module found. Please install it with "pip install pytz"') + raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"') + + +class TimeZoneNameSettings(type): + def __new__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] + dct['strict'] = kwargs.pop('strict', True) + return super().__new__(cls, name, bases, dct) + + def __init__(cls, name, bases, dct, **kwargs): # type: ignore[no-untyped-def] + super().__init__(name, bases, dct) + cls.strict = kwargs.get('strict', True) + + +class TimeZoneName(str, metaclass=TimeZoneNameSettings): # type: ignore[misc] + """If the mode is not strict matching, it is case-insensitive with whitespace stripped. + Value is then coerced to the correct case.""" + + __slots__: List[str] = [] + allowed_values = set(get_timezones()) + allowed_values_list = list(allowed_values) + allowed_values_list.sort() + allowed_values_upper_to_correct = {val.upper(): val for val in allowed_values} + + @classmethod + def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName: + """ + Validate a time zone name from the provided str value. + + Args: + __input_value: The str value to be validated. + _: The Pydantic ValidationInfo. + + Returns: + The validated time zone name. + + Raises: + PydanticCustomError: If the timezone name is not valid. + """ + if __input_value not in cls.allowed_values: # be fast for the most common case + if not cls.strict: + upper_value = __input_value.strip().upper() + if upper_value in cls.allowed_values_upper_to_correct: + return cls(cls.allowed_values_upper_to_correct[upper_value]) + raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.') + return cls(__input_value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _: type[Any], __: GetCoreSchemaHandler + ) -> core_schema.AfterValidatorFunctionSchema: + """ + Return a Pydantic CoreSchema with the ISO 639-3 language code validation. + + Args: + _: The source type. + __: The handler to get the CoreSchema. + + Returns: + A Pydantic CoreSchema with the ISO 639-3 language code validation. + + """ + return core_schema.with_info_after_validator_function( + cls._validate, + core_schema.str_schema(min_length=1), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> dict[str, Any]: + """ + Return a Pydantic JSON Schema with the ISO 639-3 language code validation. + + Args: + schema: The Pydantic CoreSchema. + handler: The handler to get the JSON Schema. + + Returns: + A Pydantic JSON Schema with the ISO 639-3 language code validation. + + """ + json_schema = handler(schema) + json_schema.update({'enum': cls.allowed_values_list}) + return json_schema diff --git a/pyproject.toml b/pyproject.toml index eadf5103..824db8cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,9 @@ all = [ 'semver>=3.0.2', 'python-ulid>=1,<2; python_version<"3.9"', 'python-ulid>=1,<3; python_version>="3.9"', - 'pendulum>=3.0.0,<4.0.0' + 'pendulum>=3.0.0,<4.0.0', + 'pytz>=2024.1', + 'tzdata>=2024.1', ] phonenumbers = ['phonenumbers>=8,<9'] pycountry = ['pycountry>=23'] diff --git a/requirements/linting.in b/requirements/linting.in index 06a5fced..76054199 100644 --- a/requirements/linting.in +++ b/requirements/linting.in @@ -2,3 +2,4 @@ pre-commit mypy annotated-types ruff +types-pytz \ No newline at end of file diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 303098c1..c39c88f0 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -19,6 +19,7 @@ from pydantic_extra_types.pendulum_dt import DateTime from pydantic_extra_types.script_code import ISO_15924 from pydantic_extra_types.semantic_version import SemanticVersion +from pydantic_extra_types.timezone_name import TimeZoneName from pydantic_extra_types.ulid import ULID languages = [lang.alpha_3 for lang in pycountry.languages] @@ -36,6 +37,8 @@ scripts = [script.alpha_4 for script in pycountry.scripts] +timezone_names = TimeZoneName.allowed_values_list + everyday_currencies.sort() @@ -335,6 +338,22 @@ 'type': 'object', }, ), + ( + TimeZoneName, + { + 'properties': { + 'x': { + 'title': 'X', + 'type': 'string', + 'enum': timezone_names, + 'minLength': 1, + } + }, + 'required': ['x'], + 'title': 'Model', + 'type': 'object', + }, + ), ], ) def test_json_schema(cls, expected): diff --git a/tests/test_timezone_names.py b/tests/test_timezone_names.py new file mode 100644 index 00000000..471ca27b --- /dev/null +++ b/tests/test_timezone_names.py @@ -0,0 +1,78 @@ +import re + +import pytest +import pytz +from pydantic import BaseModel, ValidationError + +from pydantic_extra_types.timezone_name import TimeZoneName + +has_zone_info = True +try: + from zoneinfo import available_timezones +except ImportError: + has_zone_info = False + +pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones] +pytz_zones_bad.extend([(f' {zone}', zone) for zone in pytz.all_timezones_set]) + + +class TZNameCheck(BaseModel): + timezone_name: TimeZoneName + + +class TZNonStrict(TimeZoneName, strict=False): + pass + + +class NonStrictTzName(BaseModel): + timezone_name: TZNonStrict + + +@pytest.mark.parametrize('zone', pytz.all_timezones) +def test_all_timezones_non_strict_pytz(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + +@pytest.mark.parametrize('zone', pytz_zones_bad) +def test_all_timezones_pytz_lower(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1] + + +def test_fail_non_existing_timezone(): + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for TZNameCheck\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + TZNameCheck(timezone_name='mars') + + with pytest.raises( + ValidationError, + match=re.escape( + '1 validation error for NonStrictTzName\n' + 'timezone_name\n ' + 'Invalid timezone name. ' + "[type=TimeZoneName, input_value='mars', input_type=str]" + ), + ): + NonStrictTzName(timezone_name='mars') + + +if has_zone_info: + zones = list(available_timezones()) + zones.sort() + zones_bad = [(zone.lower(), zone) for zone in zones] + + @pytest.mark.parametrize('zone', zones) + def test_all_timezones_zone_info(zone): + assert TZNameCheck(timezone_name=zone).timezone_name == zone + assert NonStrictTzName(timezone_name=zone).timezone_name == zone + + @pytest.mark.parametrize('zone', zones_bad) + def test_all_timezones_zone_info_NonStrict(zone): + assert NonStrictTzName(timezone_name=zone[0]).timezone_name == zone[1]