Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type hints for methods #371

Merged
merged 15 commits into from
Jan 7, 2025
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 12.1.0 (unreleased)

Features:

- Typing: Add type hints for parser methods ([#367](https://github.com/sloria/environs/issues/367)).
Thanks [OkeyDev](https://github/OkeyDev) for the PR.

## 12.0.0 (2025-01-06)

Features:
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ classifiers = [
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.9"
dependencies = ["python-dotenv", "marshmallow>=3.13.0"]
dependencies = [
"python-dotenv",
"marshmallow>=3.13.0",
"typing-extensions; python_version < '3.11'",
]

[project.urls]
Changelog = "https://github.com/sloria/environs/blob/master/CHANGELOG.md"
Expand Down
88 changes: 50 additions & 38 deletions src/environs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,39 @@

import collections
import contextlib
import datetime as dt
import decimal
import functools
import inspect
import json as pyjson
import logging
import os
import re
import typing
import uuid
from collections.abc import Mapping
from datetime import timedelta
from enum import Enum
from pathlib import Path
from urllib.parse import ParseResult, urlparse

import marshmallow as ma
from dj_database_url import DBConfig
from dotenv.main import _walk_to_root, load_dotenv

from .types import (
DictFieldMethod,
EnumFuncMethod,
EnumT,
ErrorList,
ErrorMapping,
FieldFactory,
FieldMethod,
FieldOrFactory,
ListFieldMethod,
ParserMethod,
Subcast,
)

if typing.TYPE_CHECKING:
try:
from dj_database_url import DBConfig
Expand All @@ -29,15 +46,9 @@
_T = typing.TypeVar("_T")
_StrType = str
_BoolType = bool

ErrorMapping = typing.Mapping[str, list[str]]
ErrorList = list[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[type, typing.Callable[..., _T], ma.fields.Field]
FieldType = type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable[..., typing.Any]

_IntType = int
_ListType = list
_DictType = dict

_EXPANDED_VAR_PATTERN = re.compile(r"(?<!\\)\$\{([A-Za-z0-9_]+)(:-[^\}:]*)?\}")
# Ordered duration strings, loosely based on the [GEP-2257](https://gateway-api.sigs.k8s.io/geps/gep-2257/) spec
Expand Down Expand Up @@ -91,12 +102,12 @@ def _field2method(
*,
preprocess: typing.Callable | None = None,
preprocess_kwarg_names: typing.Sequence[str] = tuple(),
) -> ParserMethod:
) -> typing.Any:
def method(
self: Env,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast | None = None,
subcast: Subcast[_T] | None = None,
*,
# Subset of relevant marshmallow.Field kwargs
load_default: typing.Any = ma.missing,
Expand Down Expand Up @@ -161,13 +172,13 @@ def method(
self._errors[parsed_key].extend(error.messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method


def _func2method(func: typing.Callable, method_name: str) -> ParserMethod:
def _func2method(func: typing.Callable[..., _T], method_name: str) -> typing.Any:
def method(
self: Env,
name: str,
Expand Down Expand Up @@ -209,7 +220,7 @@ def method(
self._errors[parsed_key].extend(messages)
else:
self._values[parsed_key] = value
return value
return typing.cast(typing.Optional[_T], value)

method.__name__ = method_name
return method
Expand Down Expand Up @@ -292,10 +303,7 @@ def _preprocess_json(value: str | typing.Mapping | list, **kwargs):
raise ma.ValidationError("Not valid JSON.") from error


_EnumT = typing.TypeVar("_EnumT", bound=Enum)


def _enum_parser(value, type: type[_EnumT], ignore_case: bool = False) -> _EnumT:
def _enum_parser(value, type: type[EnumT], ignore_case: bool = False) -> EnumT:
if isinstance(value, type):
return value

Expand Down Expand Up @@ -371,7 +379,7 @@ def deserialize( # type: ignore[override]
data: typing.Mapping[str, typing.Any] | None = None,
**kwargs,
) -> ParseResult:
ret = super().deserialize(value, attr, data, **kwargs)
ret = typing.cast(str, super().deserialize(value, attr, data, **kwargs))
return urlparse(ret)


Expand Down Expand Up @@ -423,20 +431,20 @@ def _deserialize(self, value, *args, **kwargs) -> timedelta:
class Env:
"""An environment variable reader."""

__call__: ParserMethod = _field2method(ma.fields.Raw, "__call__")
__call__: FieldMethod[typing.Any] = _field2method(ma.fields.Raw, "__call__")

int = _field2method(ma.fields.Int, "int")
bool = _field2method(ma.fields.Bool, "bool")
str = _field2method(ma.fields.Str, "str")
float = _field2method(ma.fields.Float, "float")
decimal = _field2method(ma.fields.Decimal, "decimal")
list = _field2method(
int: FieldMethod[int] = _field2method(ma.fields.Int, "int")
bool: FieldMethod[bool] = _field2method(ma.fields.Bool, "bool")
str: FieldMethod[str] = _field2method(ma.fields.Str, "str")
float: FieldMethod[float] = _field2method(ma.fields.Float, "float")
decimal: FieldMethod[decimal.Decimal] = _field2method(ma.fields.Decimal, "decimal")
list: ListFieldMethod = _field2method(
_make_list_field,
"list",
preprocess=_preprocess_list,
preprocess_kwarg_names=("subcast", "delimiter"),
)
dict = _field2method(
dict: DictFieldMethod = _field2method(
ma.fields.Dict,
"dict",
preprocess=_preprocess_dict,
Expand All @@ -448,16 +456,20 @@ class Env:
"delimiter",
),
)
json = _field2method(ma.fields.Field, "json", preprocess=_preprocess_json)
datetime = _field2method(ma.fields.DateTime, "datetime")
date = _field2method(ma.fields.Date, "date")
time = _field2method(ma.fields.Time, "time")
path = _field2method(_PathField, "path")
log_level = _field2method(_LogLevelField, "log_level")
timedelta = _field2method(_TimeDeltaField, "timedelta")
uuid = _field2method(ma.fields.UUID, "uuid")
url = _field2method(_URLField, "url")
enum = _func2method(_enum_parser, "enum")
json: FieldMethod[_ListType | _DictType] = _field2method(
ma.fields.Field, "json", preprocess=_preprocess_json
)
datetime: FieldMethod[dt.datetime] = _field2method(ma.fields.DateTime, "datetime")
date: FieldMethod[dt.date] = _field2method(ma.fields.Date, "date")
time: FieldMethod[dt.time] = _field2method(ma.fields.Time, "time")
timedelta: FieldMethod[dt.timedelta] = _field2method(_TimeDeltaField, "timedelta")
path: FieldMethod[Path] = _field2method(_PathField, "path")
log_level: FieldMethod[_IntType] = _field2method(_LogLevelField, "log_level")

uuid: FieldMethod[uuid.UUID] = _field2method(ma.fields.UUID, "uuid")
url: FieldMethod[ParseResult] = _field2method(_URLField, "url")

enum: EnumFuncMethod = _func2method(_enum_parser, "enum")
dj_db_url = _func2method(_dj_db_url_parser, "dj_db_url")
dj_email_url = _func2method(_dj_email_url_parser, "dj_email_url")
dj_cache_url = _func2method(_dj_cache_url_parser, "dj_cache_url")
Expand Down
89 changes: 89 additions & 0 deletions src/environs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Custom types and type aliases.

.. warning::

This module is provisional. Types may be modified, added, and removed between minor releases.
"""

from __future__ import annotations

import enum
import typing

try:
from typing import Unpack
except ImportError: # Remove when dropping Python 3.10
from typing_extensions import Unpack

import marshmallow as ma

T = typing.TypeVar("T")
EnumT = typing.TypeVar("EnumT", bound=enum.Enum)


ErrorMapping = typing.Mapping[str, list[str]]
ErrorList = list[str]
FieldFactory = typing.Callable[..., ma.fields.Field]
Subcast = typing.Union[type, typing.Callable[..., T], ma.fields.Field]
FieldType = type[ma.fields.Field]
FieldOrFactory = typing.Union[FieldType, FieldFactory]
ParserMethod = typing.Callable[..., T]


class BaseMethodKwargs(typing.TypedDict, total=False):
# Subset of relevant marshmallow.Field kwargs shared by all parser methods
load_default: typing.Any
validate: (
typing.Callable[[typing.Any], typing.Any]
| typing.Iterable[typing.Callable[[typing.Any], typing.Any]]
| None
)
required: bool
allow_none: bool | None
error_messages: dict[str, str] | None
metadata: typing.Mapping[str, typing.Any] | None


class FieldMethod(typing.Generic[T]):
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast[T] | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> T | None: ...


class ListFieldMethod:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
subcast: Subcast[T] | None = None,
*,
delimiter: str | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> list | None: ...


class DictFieldMethod:
def __call__(
self,
name: str,
default: typing.Any = ma.missing,
*,
subcast_keys: Subcast[T] | None = None,
subcast_values: Subcast[T] | None = None,
delimiter: str | None = None,
**kwargs: Unpack[BaseMethodKwargs],
) -> dict | None: ...


class EnumFuncMethod:
def __call__(
self,
value,
type: type[EnumT],
default: EnumT | None = None,
ignore_case: bool = False,
) -> EnumT | None: ...
41 changes: 41 additions & 0 deletions tests/mypy_test_cases/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Test cases for type hints of environs.Env.

To run these, use: ::

tox -e mypy-marshmallow3

Or ::

tox -e mypy-marshmallowdev
"""

import datetime as dt
import decimal
import pathlib
import uuid
from typing import Any
from urllib.parse import ParseResult

import environs

env = environs.Env()

A: int | None = env.int("FOO", None)
B: bool | None = env.bool("FOO", None)
C: str | None = env.str("FOO", None)
D: float | None = env.float("FOO", None)
E: decimal.Decimal | None = env.decimal("FOO", None)
F: list | None = env.list("FOO", None)
G: list[int] | None = env.list("FOO", None, subcast=int)
H: dict | None = env.dict("FOO", None)
J: dict[str, int] | None = env.dict("FOO", None, subcast_keys=str, subcast_values=int)
K: list | dict | None = env.json("FOO", None)
L: dt.datetime | None = env.datetime("FOO", None)
M: dt.date | None = env.date("FOO", None)
N: dt.time | None = env.time("FOO", None)
P: dt.timedelta | None = env.timedelta("FOO", None)
Q: pathlib.Path | None = env.path("FOO", None)
R: int | None = env.log_level("FOO", None)
S: uuid.UUID | None = env.uuid("FOO", None)
T: ParseResult | None = env.url("FOO", None)
U: Any = env("FOO", None)
Loading