Skip to content

Commit

Permalink
add timezone name validation
Browse files Browse the repository at this point in the history
  • Loading branch information
07pepa committed Jul 3, 2024
1 parent 1cddee1 commit de4199f
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ lint:

.PHONY: mypy
mypy:
mypy pydantic_extra_types
@mypy pydantic_extra_types

.PHONY: test
test:
Expand Down
139 changes: 139 additions & 0 deletions pydantic_extra_types/timezone_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Time zone name validation and serialization module."""

from __future__ import annotations

import importlib
import sys
import warnings
from typing import Any, List

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import PydanticCustomError, core_schema


def _is_available(name: str) -> bool:
    """Report whether the module called ``name`` can be imported.

    The probe performs a real import, so a module that is found stays loaded
    afterwards.

    Args:
        name: The module name to probe.

    Returns:
        ``True`` if the import succeeds, ``False`` if it raises ``ModuleNotFoundError``.
    """
    try:
        importlib.import_module(name)
    except ModuleNotFoundError:
        return False
    else:
        return True


def _tz_provider_from_zoneinfo() -> set[str]:
    """Collect the time zone names known to the standard-library ``zoneinfo`` module.

    Returns:
        A fresh set holding every name reported by ``zoneinfo.available_timezones()``.
    """
    # Imported lazily: the module is only touched when this provider is selected.
    import zoneinfo

    return set(zoneinfo.available_timezones())


def _tz_provider_from_pytz() -> set[str]:
    """Collect the time zone names known to the ``pytz`` package.

    Returns:
        A fresh set built from ``pytz.all_timezones``.
    """
    # Imported lazily: pytz is optional and only needed when this provider is selected.
    import pytz

    return set(pytz.all_timezones)


def _warn_about_pytz_usage() -> None:
    """Emit a warning recommending the standard-library time zone support over pytz.

    No category is passed to ``warnings.warn``, so the default ``UserWarning`` is used.
    """
    message = (
        'Projects using Python 3.9 or later should be using the support now included as part of the standard library zone-info. '
        'Please consider switching to the standard library module.'
    )
    warnings.warn(message)


def get_timezones() -> set[str]:
    """Determine the timezone provider and return the available timezone names.

    Providers are tried in this order:

    1. ``zoneinfo`` backed by the ``tzdata`` package, when both can be imported.
    2. ``pytz`` (a warning is emitted on Python newer than 3.8).
    3. ``zoneinfo`` on its own, when it reports at least one name without the
       ``tzdata`` package (for example from the operating system's database).

    Returns:
        The set of time zone names from the first provider that applies.

    Raises:
        ImportError: If no provider yields any time zone names.
    """
    has_zoneinfo = _is_available('zoneinfo')
    if has_zoneinfo and _is_available('tzdata'):
        return _tz_provider_from_zoneinfo()

    if _is_available('pytz'):
        if sys.version_info[:2] > (3, 8):
            _warn_about_pytz_usage()
        return _tz_provider_from_pytz()

    if has_zoneinfo:
        # Last resort before failing: zoneinfo may still find names without the
        # tzdata package, so only give up when it reports none at all.
        system_zones = _tz_provider_from_zoneinfo()
        if system_zones:
            return system_zones

    if sys.version_info[:2] == (3, 8):
        raise ImportError('No pytz module found. Please install it with "pip install pytz"')
    raise ImportError('No timezone provider found. Please install tzdata with "pip install tzdata"')


class TimeZoneNameSettings(type):
    """Metaclass that records a ``strict`` flag on every class it creates.

    The flag is read from the ``strict`` keyword of the class statement, e.g.
    ``class Loose(Base, strict=False)``. When the keyword is omitted the flag
    is set to ``True``; it is not taken over from the base classes.
    """

    def __new__(cls, name, bases, dct, **kwargs):  # type: ignore[no-untyped-def]
        # Consume the keyword here; only the plain three arguments reach ``type.__new__``.
        strict_flag = kwargs.pop('strict', True)
        dct['strict'] = strict_flag
        return super().__new__(cls, name, bases, dct)

    def __init__(cls, name, bases, dct, **kwargs):  # type: ignore[no-untyped-def]
        super().__init__(name, bases, dct)
        strict_flag = kwargs.get('strict', True)
        cls.strict = strict_flag


class TimeZoneName(str, metaclass=TimeZoneNameSettings):  # type: ignore[misc]
    """A ``str`` subclass whose values are validated as known time zone names.

    In strict mode (the default) the input must equal a known name exactly.
    If the mode is not strict, matching is case-insensitive with surrounding
    whitespace stripped, and the value is coerced to the correctly cased name.
    The mode is chosen per subclass: ``class Loose(TimeZoneName, strict=False)``.
    """

    __slots__: List[str] = []
    # Names from the provider picked by ``get_timezones()``; computed once when the class is created.
    allowed_values = set(get_timezones())
    # Sorted copy used for the JSON schema ``enum``.
    allowed_values_list = sorted(allowed_values)
    # Upper-cased name -> correctly cased name, used by the non-strict lookup.
    allowed_values_upper_to_correct = {val.upper(): val for val in allowed_values}

    @classmethod
    def _validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> TimeZoneName:
        """
        Validate a time zone name from the provided str value.

        Args:
            __input_value: The str value to be validated.
            _: The Pydantic ValidationInfo (unused).

        Returns:
            The validated time zone name, as an instance of ``cls``.

        Raises:
            PydanticCustomError: If the timezone name is not valid.
        """
        if __input_value in cls.allowed_values:  # be fast for the most common case
            return cls(__input_value)
        if not cls.strict:
            # Non-strict mode: ignore surrounding whitespace and letter case, then
            # hand back the canonical spelling.
            canonical = cls.allowed_values_upper_to_correct.get(__input_value.strip().upper())
            if canonical is not None:
                return cls(canonical)
        raise PydanticCustomError('TimeZoneName', 'Invalid timezone name.')

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _: type[Any], __: GetCoreSchemaHandler
    ) -> core_schema.AfterValidatorFunctionSchema:
        """
        Return a Pydantic CoreSchema with the time zone name validation.

        The schema is a non-empty string schema followed by ``_validate``.

        Args:
            _: The source type (unused).
            __: The handler to get the CoreSchema (unused).

        Returns:
            A Pydantic CoreSchema with the time zone name validation.
        """
        return core_schema.with_info_after_validator_function(
            cls._validate,
            core_schema.str_schema(min_length=1),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        """
        Return a Pydantic JSON Schema listing the accepted time zone names.

        The ``enum`` holds the canonical names in sorted order and is emitted
        whether or not the class is strict.

        Args:
            schema: The Pydantic CoreSchema.
            handler: The handler to get the JSON Schema.

        Returns:
            The JSON Schema produced by ``handler`` with an ``enum`` entry added.
        """
        json_schema = handler(schema)
        # Hand out a copy so that code mutating the generated schema cannot
        # corrupt the shared class-level list.
        json_schema.update({'enum': list(cls.allowed_values_list)})
        return json_schema
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ all = [
'semver>=3.0.2',
'python-ulid>=1,<2; python_version<"3.9"',
'python-ulid>=1,<3; python_version>="3.9"',
'pendulum>=3.0.0,<4.0.0'
'pendulum>=3.0.0,<4.0.0',
'pytz>=2024.1',
'tzdata>=2024.1',
]
phonenumbers = ['phonenumbers>=8,<9']
pycountry = ['pycountry>=23']
Expand Down
1 change: 1 addition & 0 deletions requirements/linting.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pre-commit
mypy
annotated-types
ruff
types-pytz
19 changes: 19 additions & 0 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pydantic_extra_types.pendulum_dt import DateTime
from pydantic_extra_types.script_code import ISO_15924
from pydantic_extra_types.semantic_version import SemanticVersion
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic_extra_types.ulid import ULID

languages = [lang.alpha_3 for lang in pycountry.languages]
Expand All @@ -36,6 +37,8 @@

scripts = [script.alpha_4 for script in pycountry.scripts]

timezone_names = TimeZoneName.allowed_values_list

everyday_currencies.sort()


Expand Down Expand Up @@ -335,6 +338,22 @@
'type': 'object',
},
),
(
TimeZoneName,
{
'properties': {
'x': {
'title': 'X',
'type': 'string',
'enum': timezone_names,
'minLength': 1,
}
},
'required': ['x'],
'title': 'Model',
'type': 'object',
},
),
],
)
def test_json_schema(cls, expected):
Expand Down
78 changes: 78 additions & 0 deletions tests/test_timezone_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import re

import pytest
import pytz
from pydantic import BaseModel, ValidationError

from pydantic_extra_types.timezone_name import TimeZoneName

# The zoneinfo-based tests at the bottom of this module only run when the
# standard-library module can be imported.
try:
    from zoneinfo import available_timezones
except ImportError:
    has_zone_info = False
else:
    has_zone_info = True

# (input, expected) pairs for the non-strict model: lower-cased names first,
# then names with leading whitespace.
# Both halves iterate the ordered ``all_timezones`` list rather than
# ``all_timezones_set``, so the parametrized test ids come out in the same
# order in every process.
pytz_zones_bad = [(zone.lower(), zone) for zone in pytz.all_timezones]
pytz_zones_bad.extend((f' {zone}', zone) for zone in pytz.all_timezones)


# Model with a single field using the default (strict) TimeZoneName.
class TZNameCheck(BaseModel):
    timezone_name: TimeZoneName


# TimeZoneName variant created with strict=False, i.e. the case-insensitive,
# whitespace-stripping mode described in the TimeZoneName docstring.
class TZNonStrict(TimeZoneName, strict=False):
    pass


# Model with a single field using the non-strict TZNonStrict type.
class NonStrictTzName(BaseModel):
    timezone_name: TZNonStrict


@pytest.mark.parametrize('zone', pytz.all_timezones)
def test_all_timezones_non_strict_pytz(zone):
    """Every pytz zone name passes through the strict and the non-strict model unchanged."""
    strict_model = TZNameCheck(timezone_name=zone)
    assert strict_model.timezone_name == zone

    lenient_model = NonStrictTzName(timezone_name=zone)
    assert lenient_model.timezone_name == zone


@pytest.mark.parametrize('zone', pytz_zones_bad)
def test_all_timezones_pytz_lower(zone):
    """The non-strict model coerces each altered spelling back to the canonical pytz name."""
    altered_name, canonical_name = zone
    validated = NonStrictTzName(timezone_name=altered_name)
    assert validated.timezone_name == canonical_name


def test_fail_non_existing_timezone():
    """An unknown name is rejected by both the strict and the non-strict model."""
    # pydantic prints each error as '<location>\n  <message> [type=...]': the
    # message line is indented by two spaces, so the expected text must carry
    # exactly two spaces after the newline for the escaped pattern to match.
    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for TZNameCheck\n'
            'timezone_name\n  '
            'Invalid timezone name. '
            "[type=TimeZoneName, input_value='mars', input_type=str]"
        ),
    ):
        TZNameCheck(timezone_name='mars')

    with pytest.raises(
        ValidationError,
        match=re.escape(
            '1 validation error for NonStrictTzName\n'
            'timezone_name\n  '
            'Invalid timezone name. '
            "[type=TimeZoneName, input_value='mars', input_type=str]"
        ),
    ):
        NonStrictTzName(timezone_name='mars')


if has_zone_info:
    # Sorted so the parametrized cases are listed in alphabetical order.
    zones = sorted(available_timezones())
    # (lower-cased input, expected canonical name) pairs for the non-strict model.
    zones_bad = [(name.lower(), name) for name in zones]

    @pytest.mark.parametrize('zone', zones)
    def test_all_timezones_zone_info(zone):
        """Every zoneinfo name passes through the strict and the non-strict model unchanged."""
        strict_model = TZNameCheck(timezone_name=zone)
        assert strict_model.timezone_name == zone

        lenient_model = NonStrictTzName(timezone_name=zone)
        assert lenient_model.timezone_name == zone

    @pytest.mark.parametrize('zone', zones_bad)
    def test_all_timezones_zone_info_NonStrict(zone):
        """The non-strict model coerces each lower-cased zoneinfo name back to its canonical form."""
        lowered_name, canonical_name = zone
        validated = NonStrictTzName(timezone_name=lowered_name)
        assert validated.timezone_name == canonical_name

0 comments on commit de4199f

Please sign in to comment.