diff --git a/Dockerfile b/Dockerfile index 81e7019..4b81054 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Install pytest python library as well as add all files in current directory -FROM python:3.7 AS base +FROM python:3.11 AS base WORKDIR /usr/src/app RUN apt-get update \ && apt-get install -y enchant \ diff --git a/jsonmodels/builders.py b/jsonmodels/builders.py index a19efdc..4437d62 100644 --- a/jsonmodels/builders.py +++ b/jsonmodels/builders.py @@ -1,79 +1,95 @@ """Builders to generate in memory representation of model and fields tree.""" -from __future__ import absolute_import from collections import defaultdict - +from typing import Any, Dict, List, Optional, Set import six from . import errors -from .fields import NotSet - +from .fields import NotSet, Value +from .types import Builder, Field, JSONSchemaProperty, JSONSchemaTypeName, Model -class Builder(object): - def __init__(self, parent=None, nullable=False, default=NotSet): +class BaseBuilder: + def __init__( + self, + parent: Optional[Builder] = None, + nullable: bool = False, + default: Any = NotSet, + ) -> None: self.parent = parent - self.types_builders = {} - self.types_count = defaultdict(int) - self.definitions = set() + self.types_builders: Dict[type[Model], Builder] = {} + self.types_count: Dict[type[Model], int] = defaultdict(int) + self.definitions: Set[Builder] = set() self.nullable = nullable self.default = default @property - def has_default(self): + def has_default(self) -> bool: return self.default is not NotSet - def register_type(self, type, builder): + def register_type(self, model_type: type[Model], builder: Builder) -> None: if self.parent: - return self.parent.register_type(type, builder) + self.parent.register_type(model_type, builder) + return - self.types_count[type] += 1 - if type not in self.types_builders: - self.types_builders[type] = builder + self.types_count[model_type] += 1 + if model_type not in self.types_builders: + self.types_builders[model_type] = builder - def get_builder(self, type): + def get_builder(self, model_type: type[Model]) -> Builder: if self.parent: - return self.parent.get_builder(type) + return self.parent.get_builder(model_type) - return self.types_builders[type] + return self.types_builders[model_type] - def count_type(self, type): + def count_type(self, model_type: type[Model]) -> int: if self.parent: - return self.parent.count_type(type) + return self.parent.count_type(model_type) - return self.types_count[type] + return self.types_count[model_type] @staticmethod - def maybe_build(value): + def maybe_build(value: Value) -> JSONSchemaProperty | Value: return value.build() if isinstance(value, Builder) else value - def add_definition(self, builder): + def add_definition(self, builder: Builder) -> None: if self.parent: return self.parent.add_definition(builder) self.definitions.add(builder) + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + raise NotImplementedError() + + @property + def is_definition(self) -> bool: + raise NotImplementedError() + + @property + def type_name(self) -> str: + raise NotImplementedError() -class ObjectBuilder(Builder): - def __init__(self, model_type, *args, **kwargs): - super(ObjectBuilder, self).__init__(*args, **kwargs) - self.properties = {} - self.required = [] +class ObjectBuilder(BaseBuilder): + def __init__(self, model_type: type[Model], *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.properties: Dict[str, str | JSONSchemaProperty] = {} + self.required: List[str] = [] 
self.type = model_type self.register_type(self.type, self) - def add_field(self, name, field, schema): - _apply_validators_modifications(schema, field) + def add_field(self, name: str, field: Field, schema: str | JSONSchemaProperty) -> None: + if not isinstance(schema, str): + _apply_validators_modifications(schema, field) if isinstance(schema, dict) and field.help_text: schema["description"] = field.help_text self.properties[name] = schema if field.required: self.required.append(name) - def build(self): + def build(self) -> str | JSONSchemaProperty: builder = self.get_builder(self.type) if self.is_definition and not self.is_root: self.add_definition(builder) @@ -83,27 +99,27 @@ def build(self): return builder.build_definition() @property - def type_name(self): + def type_name(self) -> str: module_name = '{module}.{name}'.format( module=self.type.__module__, name=self.type.__name__, ) return module_name.replace('.', '_').lower() - def build_definition(self, add_definitions=True): - properties = dict( + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + properties: Dict[str, str | JSONSchemaProperty] = dict( (name, self.maybe_build(value)) for name, value in self.properties.items() ) - schema = { + schema: JSONSchemaProperty = { 'type': 'object', 'additionalProperties': False, 'properties': properties, } if self.required: - schema['required'] = self.required + schema['required'] = list(self.required) if self.definitions and add_definitions: schema['definitions'] = dict( @@ -114,7 +130,7 @@ def build_definition(self, add_definitions=True): return schema @property - def is_definition(self): + def is_definition(self) -> bool: if self.count_type(self.type) > 1: return True elif self.parent: @@ -123,35 +139,30 @@ def is_definition(self): return False @property - def is_root(self): + def is_root(self) -> bool: return not bool(self.parent) -def _apply_validators_modifications(field_schema, field): +def _apply_validators_modifications(field_schema: JSONSchemaProperty, field: Field) -> None: for validator in field.validators: - try: + if hasattr(validator, "modify_schema"): validator.modify_schema(field_schema) - except AttributeError: - pass # arrays may have separate validators for each item. # we should also add those validators to the schema. 
if "items" in field_schema: for validator in field.item_validators: - try: + if hasattr(validator, "modify_schema"): validator.modify_schema(field_schema["items"]) - except AttributeError: - pass - - -class PrimitiveBuilder(Builder): - def __init__(self, type, *args, **kwargs): - super(PrimitiveBuilder, self).__init__(*args, **kwargs) - self.type = type +class PrimitiveBuilder(BaseBuilder): + def __init__(self, value_type: type, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.type = value_type - def build(self): - schema = {} + def build(self) -> JSONSchemaProperty: + obj_type: JSONSchemaTypeName + schema: JSONSchemaProperty = {} if issubclass(self.type, six.string_types): obj_type = 'string' elif issubclass(self.type, bool): @@ -174,19 +185,21 @@ def build(self): return schema -class ListBuilder(Builder): +class ListBuilder(BaseBuilder): - def __init__(self, *args, **kwargs): - super(ListBuilder, self).__init__(*args, **kwargs) - self.schemas = [] + parent: Builder - def add_type_schema(self, schema): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schemas: list[Builder | JSONSchemaProperty] = [] + + def add_type_schema(self, schema: Builder | JSONSchemaProperty) -> None: self.schemas.append(schema) - def build(self): - schema = {'type': 'array'} + def build(self) -> str | JSONSchemaProperty: + schema: JSONSchemaProperty = {'type': 'array'} if self.nullable: - self.add_type_schema({'type': 'null'}) + self.add_type_schema({'type': 'null'}) # <- probably a bug if self.has_default: schema["default"] = [self.to_struct(i) for i in self.default] @@ -201,27 +214,28 @@ def build(self): return schema @property - def is_definition(self): + def is_definition(self) -> bool: return self.parent.is_definition @staticmethod - def to_struct(item): + def to_struct(item: Value) -> Value: from .models import Base if isinstance(item, Base): return item.to_struct() return item -class EmbeddedBuilder(Builder): +class EmbeddedBuilder(BaseBuilder): + parent: Builder - def __init__(self, *args, **kwargs): - super(EmbeddedBuilder, self).__init__(*args, **kwargs) - self.schemas = [] + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.schemas: list[Builder | JSONSchemaProperty] = [] - def add_type_schema(self, schema): + def add_type_schema(self, schema: Builder | JSONSchemaProperty) -> None: self.schemas.append(schema) - def build(self): + def build(self) -> JSONSchemaProperty: if self.nullable: self.add_type_schema({'type': 'null'}) @@ -239,5 +253,5 @@ def build(self): return schema @property - def is_definition(self): + def is_definition(self) -> bool: return self.parent.is_definition diff --git a/jsonmodels/collections.py b/jsonmodels/collections.py index 6c756eb..4dc34ee 100644 --- a/jsonmodels/collections.py +++ b/jsonmodels/collections.py @@ -1,4 +1,6 @@ - +from typing import Any, Iterable +from .types import CollectionField +from typing_extensions import override class ModelCollection(list): @@ -9,14 +11,16 @@ class ModelCollection(list): """ - def __init__(self, field): + def __init__(self, field: CollectionField) -> None: super(ModelCollection, self).__init__() self.field = field - def append(self, value): + @override + def append(self, value: Any) -> None: self.field.validate_single_value(value) super(ModelCollection, self).append(value) - def __setitem__(self, key, value): + @override + def __setitem__(self, index: Any, value: Any, /) -> None: self.field.validate_single_value(value) - 
super(ModelCollection, self).__setitem__(key, value) + super(ModelCollection, self).__setitem__(index, value) diff --git a/jsonmodels/errors.py b/jsonmodels/errors.py index 200584b..411079e 100644 --- a/jsonmodels/errors.py +++ b/jsonmodels/errors.py @@ -1,4 +1,6 @@ -from typing import List, Tuple, Type +from typing import Any, List, Sized, Tuple, Type + +from .types import EmbedType class ValidationError(RuntimeError): @@ -38,7 +40,7 @@ class FieldValidationError(ValidationError): Enriches a validator error with the name of the field that caused it. """ def __init__(self, model_name: str, field_name: str, - given_value: any, error: ValidatorError): + given_value: Any, error: ValidatorError): """ :param model_name: The name of the model. :param field_name: The name of the field. @@ -56,14 +58,14 @@ def __init__(self, model_name: str, field_name: str, class RequiredFieldError(ValidatorError): """ Error raised when a required field has no value """ - def __init__(self): + def __init__(self) -> None: super(RequiredFieldError, self).__init__('Field is required!') class RegexError(ValidatorError): """ Error raised by the Regex validator """ - def __init__(self, value: str, pattern: str): + def __init__(self, value: str, pattern: str) -> None: tpl = 'Value "{value}" did not match pattern "{pattern}".' super(RegexError, self).__init__(tpl.format( value=value, pattern=pattern @@ -78,7 +80,7 @@ class BadTypeError(ValidatorError): expected one """ - def __init__(self, value: any, types: Tuple, is_list: bool): + def __init__(self, value: Any, types: Tuple, is_list: bool) -> None: """ :param value: The given value. :param types: The accepted types. @@ -104,7 +106,7 @@ class AmbiguousTypeError(ValidatorError): that supports multiple types """ - def __init__(self, types: Tuple): + def __init__(self, types: tuple[EmbedType, ...]) -> None: """ The types that are allowed """ tpl = 'Cannot decide which type to choose from "{types}".' super(AmbiguousTypeError, self).__init__(tpl.format( @@ -116,7 +118,7 @@ def __init__(self, types: Tuple): class MinLengthError(ValidatorError): """ Error raised by the Length validator when too few items are present """ - def __init__(self, value: list, minimum_length: int): + def __init__(self, value: Sized, minimum_length: int) -> None: """ :param value: The given value. :param minimum_length: The minimum length expected. @@ -132,7 +134,7 @@ def __init__(self, value: list, minimum_length: int): class MaxLengthError(ValidatorError): """ Error raised by the Length validator when receiving too many items """ - def __init__(self, value: list, maximum_length: int): + def __init__(self, value: Sized, maximum_length: int) -> None: """ :param value: The given value. :param maximum_length: The maximum length expected. @@ -148,7 +150,7 @@ def __init__(self, value: list, maximum_length: int): class MinValidationError(ValidatorError): """ Error raised by the Min validator """ - def __init__(self, value, minimum_value, exclusive: bool): + def __init__(self, value: int | float, minimum_value: int | float, exclusive: bool) -> None: """ :param value: The given value. :param minimum_value: The minimum value allowed. @@ -167,7 +169,7 @@ def __init__(self, value, minimum_value, exclusive: bool): class MaxValidationError(ValidatorError): """ Error raised by the Max validator """ - def __init__(self, value, maximum_value, exclusive: bool): + def __init__(self, value: int | float, maximum_value: int | float, exclusive: bool) -> None: """ :param value: The given value. 
:param maximum_value: The maximum value allowed. @@ -186,7 +188,7 @@ def __init__(self, value, maximum_value, exclusive: bool): class EnumError(ValidatorError): """ Error raised by the Enum validator """ - def __init__(self, value: any, choices: List[any]): + def __init__(self, value: Any, choices: List[Any]) -> None: """ :param value: The given value. :param choices: The allowed choices. diff --git a/jsonmodels/fields.py b/jsonmodels/fields.py index 169830f..e3602c2 100755 --- a/jsonmodels/fields.py +++ b/jsonmodels/fields.py @@ -5,37 +5,35 @@ import re import six from dateutil.parser import parse -from typing import List, Optional, Dict, Set, Union, Pattern +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast from .collections import ModelCollection -from .errors import RequiredFieldError, BadTypeError, AmbiguousTypeError +from .errors import AmbiguousTypeError, BadTypeError, RequiredFieldError +from .types import BsonEncodable, EmbedType, Field, JSONValue, Model, PrimitiveTypeInstance, Validator, ValidatorFunction, ValidatorObject, Value + # unique marker for "no default value specified". None is not good enough since # it is a completely valid default value. NotSet = object() -# BSON compatible types, which can be returned by toBsonEncodable. -BsonEncodable = Union[ - float, str, object, Dict, List, bytes, bool, datetime.datetime, None, - Pattern, int, bytes -] - -class BaseField(object): +class BaseField: """Base class for all fields.""" - types = None + types: Tuple[Any, ...] = tuple() + validators: List[Validator] = [] def __init__( - self, - required=False, - nullable=False, - help_text=None, - validators=None, - default=NotSet, - name=None): - self.memory = WeakKeyDictionary() + self, + required: bool = False, + nullable: bool = False, + help_text: Optional[str] = None, + validators: Optional[List[Validator]] = None, + default: Value = NotSet, + name: Optional[str] = None, + ) -> None: + self.memory: WeakKeyDictionary = WeakKeyDictionary() self.required = required self.help_text = help_text self.nullable = nullable @@ -47,21 +45,24 @@ def __init__( self._default = default @property - def has_default(self): + def has_default(self) -> bool: return self._default is not NotSet - def _assign_validators(self, validators): - if validators and not isinstance(validators, list): - validators = [validators] - self.validators = validators or [] + def _assign_validators(self, validators: Validator | List[Validator] | None) -> None: + if isinstance(validators, list): + self.validators = validators + elif validators is not None: + self.validators = [validators] + else: + self.validators = [] - def __set__(self, instance, value): + def __set__(self, instance: Model, value: Any) -> None: self._finish_initialization(type(instance)) value = self.parse_value(value) self.validate(value) self.memory[instance._cache_key] = value - def __get__(self, instance, owner=None): + def __get__(self, instance: Model, owner: Model | None = None) -> Any: if instance is None: self._finish_initialization(owner) return self @@ -71,38 +72,38 @@ def __get__(self, instance, owner=None): self._check_value(instance) return self.memory[instance._cache_key] - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: pass - def _check_value(self, obj): + def _check_value(self, obj: Model) -> None: if obj._cache_key not in self.memory: self.__set__(obj, self.get_default_value()) - def validate_for_object(self, obj): + def validate_for_object(self, obj: Model) 
-> None: value = self.__get__(obj) self.validate(value) - def validate(self, value): + def validate(self, value: Any) -> None: self._check_types() self._validate_against_types(value) self._check_against_required(value) self._validate_with_custom_validators(value) - def _check_against_required(self, value): + def _check_against_required(self, value: Any) -> None: if value is None and self.required: raise RequiredFieldError() - def _validate_against_types(self, value): + def _validate_against_types(self, value: Any) -> None: if value is not None and not isinstance(value, self.types): raise BadTypeError(value, self.types, is_list=False) - def _check_types(self): + def _check_types(self) -> None: if self.types is None: tpl = 'Field "{type}" is not usable, try different field type.' raise ValueError(tpl.format(type=type(self).__name__)) @staticmethod - def _get_embed_type(value, models): + def _get_embed_type(value: Value, models: tuple[EmbedType, ...]) -> EmbedType: """ Tries to guess which of the given models is applicable to the dict. :param value: The dict to check. @@ -120,7 +121,7 @@ def _get_embed_type(value, models): in model.iterate_with_name() } for model in models if hasattr(model, "iterate_with_name") - } # type: Dict[type, Set[str]] + } matching_models = [model for model, fields in model_fields.items() if fields.issuperset(value)] @@ -131,7 +132,7 @@ def _get_embed_type(value, models): return matching_models[0] return models[0] - def toBsonEncodable(self, value: types) -> BsonEncodable: + def toBsonEncodable(self, value: Any) -> BsonEncodable: """Optionally return a bson encodable python object. Returned object should be BSON compatible. By default uses the @@ -149,11 +150,11 @@ def toBsonEncodable(self, value: types) -> BsonEncodable: """ return self.to_struct(value=value) - def to_struct(self, value): - """Cast value to Python dict.""" - return value + def to_struct(self, value: Any) -> JSONValue: + """Cast value to Python structure.""" + return cast(JSONValue, value) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse value from primitive to desired format. Each field can parse value to form it wants it to be (like string or @@ -162,17 +163,17 @@ def parse_value(self, value): """ return value - def _validate_with_custom_validators(self, value): + def _validate_with_custom_validators(self, value: Any) -> None: if value is None and self.nullable: return for validator in self.validators: try: - validator.validate(value) + cast(ValidatorObject, validator).validate(value) except AttributeError: - validator(value) + cast(ValidatorFunction, validator)(value) - def get_default_value(self): + def get_default_value(self) -> Any: """Get default value for field. Each field can specify its default. @@ -180,16 +181,16 @@ def get_default_value(self): """ return self._default if self.has_default else None - def _validate_name(self): + def _validate_name(self) -> None: if self.name is None: return if not re.match(r'^[A-Za-z_](([\w\-]*)?\w+)?$', self.name): raise ValueError('Wrong name', self.name) - def structure_name(self, default): + def structure_name(self, default: str) -> str: return self.name if self.name is not None else default - def structue_name(self, default): + def structue_name(self, default: str) -> str: warnings.warn("`structue_name` is deprecated, please use " "`structure_name`") return self.structure_name(default) @@ -199,16 +200,16 @@ class StringField(BaseField): """String field.""" - types = six.string_types + types: Tuple[Any, ...] 
= six.string_types class IntField(BaseField): """Integer field.""" - types = (int,) + types: Tuple[Any, ...] = (int,) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Cast value to `int`, e.g. from string or long""" parsed = super(IntField, self).parse_value(value) if parsed is None: @@ -223,29 +224,34 @@ class FloatField(BaseField): """Float field.""" - types = (float, int) + types: Tuple[Any, ...] = (float, int) class BoolField(BaseField): """Bool field.""" - types = (bool,) + types: Tuple[Any, ...] = (bool,) - def parse_value(self, value): + def parse_value(self, value: Value) -> Any: """Cast value to `bool`.""" parsed = super(BoolField, self).parse_value(value) return bool(parsed) if parsed is not None else None +I = TypeVar("I") + + class ListField(BaseField): """List field.""" - types = (list, tuple) + types: Tuple[Any, ...] = (list, tuple) + items_types: tuple[EmbedType, ...] + item_validators: List[Any] - def __init__(self, items_types=None, item_validators=(), omit_empty=False, - *args, **kwargs): + def __init__(self, items_types: EmbedType | tuple[EmbedType, ...] | List[EmbedType] | None=None, item_validators: Union[Any, List[Any]]=[], + omit_empty: bool=False, *args: Any, **kwargs: Any): """Init. `ListField` is **always not required**. If you want to control number @@ -262,20 +268,20 @@ def __init__(self, items_types=None, item_validators=(), omit_empty=False, self.required = False self._omit_empty = omit_empty - def get_default_value(self): + def get_default_value(self) -> Any: default = super(ListField, self).get_default_value() if default is None: return ModelCollection(self) return default - def _assign_types(self, items_types): + def _assign_types(self, items_types: EmbedType | tuple[EmbedType, ...] | List[EmbedType] | None) -> None: if items_types: - try: + if isinstance(items_types, (tuple, list)): self.items_types = tuple(items_types) - except TypeError: - self.items_types = items_types, + else: + self.items_types = (items_types, ) else: - self.items_types = tuple() + self.items_types = () types = [] for type_ in self.items_types: @@ -285,13 +291,13 @@ def _assign_types(self, items_types): types.append(type_) self.items_types = tuple(types) - def validate(self, value): + def validate(self, value: Any) -> None: super(ListField, self).validate(value) for item in value: self.validate_single_value(item) - def validate_single_value(self, value): + def validate_single_value(self, value: Any) -> None: for validator in self.item_validators: try: validator.validate(value) @@ -302,9 +308,9 @@ def validate_single_value(self, value): return if not isinstance(value, self.items_types): - raise BadTypeError(value, self.items_types, is_list=True) + raise BadTypeError(value, tuple(self.items_types), is_list=True) - def parse_value(self, values): + def parse_value(self, values: Any) -> Any: """Cast value to proper collection.""" result = self.get_default_value() @@ -316,16 +322,16 @@ def parse_value(self, values): return [self._cast_value(value) for value in values] - def _cast_value(self, value): + def _cast_value(self, value: Any) -> Any: if isinstance(value, self.items_types): return value elif isinstance(value, dict): model_type = self._get_embed_type(value, self.items_types) return model_type(**value) else: - raise BadTypeError(value, self.items_types, is_list=True) + raise BadTypeError(value, tuple(self.items_types), is_list=True) - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: super(ListField, 
self)._finish_initialization(owner) types = [] @@ -336,14 +342,14 @@ def _finish_initialization(self, owner): types.append(item_type) self.items_types = tuple(types) - def _elem_to_struct(self, value): + def _elem_to_struct(self, value: Value) -> Value | dict[str, Value]: try: return value.to_struct() except AttributeError: return value - def to_struct(self, values): - return [self._elem_to_struct(v) for v in values] \ + def to_struct(self, values: Any) -> JSONValue: + return [self._elem_to_struct(v) for v in cast(List, values)] \ if values or not self._omit_empty else None @@ -352,40 +358,52 @@ class DerivedListField(ListField): A list field that has another field for its items. """ - def __init__(self, field: BaseField, *args, **kwargs): + def __init__(self, field: BaseField | PrimitiveTypeInstance, *args: Any, **kwargs: Any): """ - :param field: The field that will be in each of the items of the list. + :param field: The field instance that will be in each of the items of the list. :param help_text: The help text of the list field. :param validators: The validators for the list field. """ + # Note: It is a bit of a hack but the signature allows many primitive + # types even though in reality we only accept BaseField instances. + # The extra types are for the type checker and our Mypy plugin. + if not isinstance(field, BaseField): + raise BadTypeError(field, (BaseField,), is_list=False) + self._field = field + + fixed_kwargs = kwargs.copy() + fixed_kwargs["items_types"] = field.types + fixed_kwargs["item_validators"] = field.validators super(DerivedListField, self).__init__( - items_types=field.types, - item_validators=field.validators, - *args, **kwargs, + *args, **fixed_kwargs, ) - def to_struct(self, values: List[any]) -> List[any]: + def to_struct(self, values: Any) -> JSONValue: """ Converts the list to its output format. :param values: The values in the list. :return: The converted values. """ - return [self._field.to_struct(value) for value in values] \ + return [self._field.to_struct(value) for value in cast(List, values)] \ if values or not self._omit_empty else None - def parse_value(self, values: List[any]) -> List[any]: + def parse_value(self, values: Any) -> Any: """ Converts the list to its internal format. :param values: The values in the list. :return: The converted values. """ + if values is None: + return None + try: return [self._field.parse_value(value) for value in values] except TypeError: raise BadTypeError(values, self._field.types, is_list=True) + return None - def validate_single_value(self, value: any) -> None: + def validate_single_value(self, value: Any) -> None: """ Validates a single value in the list. :param value: One of the values in the list. 
@@ -397,23 +415,23 @@ class EmbeddedField(BaseField): """Field for embedded models.""" - def __init__(self, model_types, *args, **kwargs): + def __init__(self, model_types: EmbedType | str | tuple[EmbedType | str, ...], *args: Any, **kwargs: Any) -> None: self._assign_model_types(model_types) super(EmbeddedField, self).__init__(*args, **kwargs) - def _assign_model_types(self, model_types): + def _assign_model_types(self, model_types: EmbedType | str | tuple[EmbedType | str, ...]) -> None: if not isinstance(model_types, (list, tuple)): model_types = (model_types,) - types = [] + types: List[EmbedType | _LazyType] = [] for type_ in model_types: if isinstance(type_, six.string_types): types.append(_LazyType(type_)) else: - types.append(type_) + types.append(cast(EmbedType, type_)) self.types = tuple(types) - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: super(EmbeddedField, self)._finish_initialization(owner) types = [] for model_type in self.types: @@ -424,23 +442,23 @@ def _finish_initialization(self, owner): self.types = tuple(types) - def validate(self, value): + def validate(self, value: Any) -> None: super(EmbeddedField, self).validate(value) try: value.validate() except AttributeError: pass - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse value to proper model type.""" if not isinstance(value, dict): - return value + return cast(EmbedType, value) embed_type = self._get_embed_type(value, self.types) return embed_type(**value) - def to_struct(self, value): - return value.to_struct() + def to_struct(self, value: Any) -> JSONValue: + return cast(Model, value).to_struct() class MapField(BaseField): @@ -453,10 +471,10 @@ class MapField(BaseField): included in the to_struct method. """ - types = (dict,) + types: Tuple[Any, ...] = (dict,) - def __init__(self, key_field: BaseField, value_field: BaseField, - **kwargs): + def __init__(self, key_field: BaseField | PrimitiveTypeInstance, value_field: BaseField | PrimitiveTypeInstance, + **kwargs: Any): """ :param key_field: The field that is responsible for converting and validating the keys in this mapping. @@ -465,10 +483,19 @@ def __init__(self, key_field: BaseField, value_field: BaseField, :param kwargs: Other keyword arguments to the base class. """ super(MapField, self).__init__(**kwargs) + + # Note: It is a bit of a hack but the signature allows many primitive + # types even though in reality we only accept BaseField instances. + # The extra types are for the type checker and our Mypy plugin. + if not isinstance(key_field, BaseField): + raise BadTypeError(key_field, (BaseField,), is_list=False) self._key_field = key_field + + if not isinstance(value_field, BaseField): + raise BadTypeError(value_field, (BaseField,), is_list=False) self._value_field = value_field - def _finish_initialization(self, owner): + def _finish_initialization(self, owner: type[Model]) -> None: """ Completes the initialization of the fields, allowing for lazy refs. 
""" @@ -476,18 +503,18 @@ def _finish_initialization(self, owner): self._key_field._finish_initialization(owner) self._value_field._finish_initialization(owner) - def get_default_value(self) -> any: + def get_default_value(self) -> Any: """ Gets the default value for this field """ default = super(MapField, self).get_default_value() if default is None and self.required: return dict() return default - def parse_value(self, values: Optional[dict]) -> Optional[dict]: + def parse_value(self, values: Any) -> Any: """ Parses the given values into a new dict. """ values = super().parse_value(values) if values is None: - return + return None items = [ (self._key_field.parse_value(key), self._value_field.parse_value(value)) @@ -495,16 +522,16 @@ def parse_value(self, values: Optional[dict]) -> Optional[dict]: ] return type(values)(items) # Preserves OrderedDict - def to_struct(self, values: Optional[dict]) -> Optional[dict]: + def to_struct(self, values: Any) -> JSONValue: """ Casts the field values into a dict. """ items = [ (self._key_field.to_struct(key), self._value_field.to_struct(value)) - for key, value in values.items() + for key, value in cast(Dict, values).items() ] - return type(values)(items) # Preserves OrderedDict + return cast(JSONValue, type(values)(items)) # Preserves OrderedDict - def validate(self, values: Optional[dict]) -> Optional[dict]: + def validate(self, values: Any) -> None: """ Validates all keys and values in the map field. :param values: The values in the mapping. @@ -517,17 +544,16 @@ def validate(self, values: Optional[dict]) -> Optional[dict]: self._value_field.validate(value) -class _LazyType(object): - - def __init__(self, path): +class _LazyType: + def __init__(self, path: str) -> None: self.path = path - def evaluate(self, base_cls): + def evaluate(self, base_cls: type[Model]) -> Any: module, type_name = _evaluate_path(self.path, base_cls) return _import(module, type_name) -def _evaluate_path(relative_path, base_cls): +def _evaluate_path(relative_path: str, base_cls: type[Model]) -> tuple[Any, str]: base_module = base_cls.__module__ modules = _get_modules(relative_path, base_module) @@ -539,7 +565,7 @@ def _evaluate_path(relative_path, base_cls): return module, type_name -def _get_modules(relative_path, base_module): +def _get_modules(relative_path: str, base_module: str) -> Any: canonical_path = relative_path.lstrip('.') canonical_modules = canonical_path.split('.') @@ -554,7 +580,7 @@ def _get_modules(relative_path, base_module): return parent_modules[:parents_amount * -1] + canonical_modules -def _import(module_name, type_name): +def _import(module_name: str, type_name: str) -> Any: module = __import__(module_name, fromlist=[type_name]) try: return getattr(module, type_name) @@ -567,9 +593,11 @@ class TimeField(StringField): """Time field.""" - types = (datetime.time,) + types: Tuple[Any, ...] = (datetime.time,) - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. 
:param str str_format: Format to cast time to (if `None` - casting to @@ -579,13 +607,14 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(TimeField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `time` object to string.""" + datetime_value = cast(datetime.time, value) if self.str_format: - return value.strftime(self.str_format) - return value.isoformat() + return datetime_value.strftime(self.str_format) + return datetime_value.isoformat() - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `time`.""" if value is None: return value @@ -598,10 +627,12 @@ class DateField(StringField): """Date field.""" - types = (datetime.date,) + types: Tuple[Any, ...] = (datetime.date,) default_format = '%Y-%m-%d' - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. :param str str_format: Format to cast date to (if `None` - casting to @@ -611,13 +642,14 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(DateField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `date` object to string.""" + date_value = cast(datetime.date, value) if self.str_format: - return value.strftime(self.str_format) - return value.strftime(self.default_format) + return date_value.strftime(self.str_format) + return date_value.strftime(self.default_format) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `date`.""" if value is None: return value @@ -630,9 +662,11 @@ class DateTimeField(StringField): """Datetime field.""" - types = (datetime.datetime,) + types: Tuple[Any, ...] = (datetime.datetime,) - def __init__(self, str_format=None, *args, **kwargs): + def __init__( + self, str_format: Optional[str] = None, *args: Any, **kwargs: Any + ) -> None: """Init. :param str str_format: Format to cast datetime to (if `None` - casting @@ -642,21 +676,22 @@ def __init__(self, str_format=None, *args, **kwargs): self.str_format = str_format super(DateTimeField, self).__init__(*args, **kwargs) - def to_struct(self, value): + def to_struct(self, value: Any) -> JSONValue: """Cast `datetime` object to string.""" + datetime_value = cast(datetime.datetime, value) if self.str_format: - return value.strftime(self.str_format) - return value.isoformat() + return datetime_value.strftime(self.str_format) + return datetime_value.isoformat() - def toBsonEncodable(self, value: datetime) -> datetime: + def toBsonEncodable(self, value: Any) -> BsonEncodable: """ Keep datetime object a datetime object, since pymongo supports that. """ if not isinstance(value, self.types): raise BadTypeError(value, self.types, is_list=False) - return value + return cast(BsonEncodable, value) - def parse_value(self, value): + def parse_value(self, value: Any) -> Any: """Parse string into instance of `datetime`.""" if isinstance(value, datetime.datetime): return value @@ -671,12 +706,12 @@ class GenericField(BaseField): Field that supports any kind of value, converting models to their correct struct, keeping ordered dictionaries in their original order. """ - types = (any,) + types: Tuple[Any, ...] 
= (any,) - def _validate_against_types(self, value) -> None: + def _validate_against_types(self, value: Value) -> None: pass - def to_struct(self, values: any) -> any: + def to_struct(self, values: Any) -> JSONValue: """ Casts value to Python structure. """ from .models import Base if isinstance(values, Base): @@ -690,4 +725,4 @@ def to_struct(self, values: any) -> any: for key, value in values.items()] return type(values)(items) # preserves OrderedDict - return values + return cast(JSONValue, values) diff --git a/jsonmodels/models.py b/jsonmodels/models.py index 350393e..3d1ebe4 100644 --- a/jsonmodels/models.py +++ b/jsonmodels/models.py @@ -1,18 +1,24 @@ -import six +from typing import Any, Dict, Generator, Tuple, Type, cast + +from jsonmodels.types import JSONSchemaProperty from . import parsers, errors from .fields import BaseField from .errors import FieldValidationError, ValidatorError, ValidationError +from .types import Field, JSONSchemaProperty, JSONValue +Values = Dict[str, Any] +Fields = Tuple[str, Field] +FieldsWithNames = Tuple[str, str, Field] -class JsonmodelMeta(type): - def __new__(cls, name, bases, attributes): +class JsonmodelMeta(type): + def __new__(cls: Type[JsonmodelMeta], name: str, bases: tuple, attributes: dict) -> type: cls.validate_fields(attributes) return super(cls, cls).__new__(cls, name, bases, attributes) @staticmethod - def validate_fields(attributes): + def validate_fields(attributes: dict[str, Any]) -> None: fields = { key: value for key, value in attributes.items() if isinstance(value, BaseField) @@ -25,15 +31,15 @@ def validate_fields(attributes): taken_names.add(structure_name) -class Base(six.with_metaclass(JsonmodelMeta, object)): +class Base(metaclass=JsonmodelMeta): """Base class for all models.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Values) -> None: self._cache_key = _CacheKey() self.populate(**kwargs) - def populate(self, **values): + def populate(self, **values: Values) -> None: """Populate values to fields. Skip non-existing.""" values = values.copy() fields = list(self.iterate_with_name()) @@ -45,7 +51,7 @@ def populate(self, **values): if name in values: self.set_field(field, name, values.pop(name)) - def get_field(self, field_name): + def get_field(self, field_name: str) -> Field: """Get field associated with given attribute.""" for attr_name, field in self: if field_name == attr_name: @@ -53,7 +59,7 @@ def get_field(self, field_name): raise errors.FieldNotFound(field_name) - def set_field(self, field, field_name, value): + def set_field(self, field: Field, field_name: str, value: Any) -> None: """ Sets the value of a field. 
""" try: field.__set__(self, value) @@ -61,12 +67,12 @@ def set_field(self, field, field_name, value): raise FieldValidationError(type(self).__name__, field_name, value, error) - def __iter__(self): + def __iter__(self) -> Generator[Fields, None, None]: """Iterate through fields and values.""" for name, field in self.iterate_over_fields(): yield name, field - def validate(self): + def validate(self) -> None: """Explicitly validate all the fields.""" for name, field in self: try: @@ -77,15 +83,15 @@ def validate(self): value, error) @classmethod - def iterate_over_fields(cls): + def iterate_over_fields(cls) -> Generator[Fields, None, None]: """Iterate through fields as `(attribute_name, field_instance)`.""" for attr in dir(cls): class_attribute = getattr(cls, attr) if isinstance(class_attribute, BaseField): - yield attr, class_attribute + yield attr, cast(Field, class_attribute) @classmethod - def iterate_with_name(cls): + def iterate_with_name(cls) -> Generator[FieldsWithNames, None, None]: """Iterate over fields, but also give `structure_name`. Format is `(attribute_name, structure_name, field_instance)`. @@ -96,16 +102,16 @@ def iterate_with_name(cls): structure_name = field.structure_name(attr_name) yield attr_name, structure_name, field - def to_struct(self): + def to_struct(self) -> JSONValue: """Cast model to Python structure.""" return parsers.to_struct(self) @classmethod - def to_json_schema(cls): + def to_json_schema(cls) -> JSONSchemaProperty: """Generate JSON schema for model.""" return parsers.to_json_schema(cls) - def __repr__(self): + def __repr__(self) -> str: attrs = {} for name, _ in self: try: @@ -122,17 +128,17 @@ def __repr__(self): ), ) - def __str__(self): + def __str__(self) -> str: return '{name} object'.format(name=self.__class__.__name__) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: try: return super(Base, self).__setattr__(name, value) except ValidatorError as error: raise FieldValidationError(type(self).__name__, name, value, error) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if type(other) is not type(self): return False @@ -152,9 +158,9 @@ def __eq__(self, other): return True - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not (self == other) -class _CacheKey(object): +class _CacheKey: """Object to identify model in memory.""" diff --git a/jsonmodels/mypy_plugin.py b/jsonmodels/mypy_plugin.py new file mode 100644 index 0000000..14ad722 --- /dev/null +++ b/jsonmodels/mypy_plugin.py @@ -0,0 +1,162 @@ +from typing import Callable, List, Type +import mypy +from mypy.plugin import Plugin, AttributeContext, FunctionContext +from mypy.types import Type as MypyType +from mypy.nodes import TypeInfo + + +JSONMODEL_TYPE = [ + "jsonmodels.fields.StringField", + "jsonmodels.fields.IntField", + "jsonmodels.fields.FloatField", + "jsonmodels.fields.BoolField", + "jsonmodels.fields.TimeField", + "jsonmodels.fields.DateField", + "jsonmodels.fields.DateTimeField", + "jsonmodels.fields.EmbeddedField", + "jsonmodels.fields.ListField", + "jsonmodels.fields.DerivedListField", + "jsonmodels.fields.MapField", + "jsonmodels.fields.GenericField" +] + + +class JSONModelsPlugin(Plugin): + def get_function_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None: + jsonmodel_fullname: str + + sym = self.lookup_fully_qualified(fullname) + if sym is None: + return None + node = sym.node + if not isinstance(node, TypeInfo): + return None + + # Find a known jsonmodel field type 
in this type's class hierarchy. + for node in node.mro: + if node.fullname in JSONMODEL_TYPE: + jsonmodel_fullname = node.fullname + break + else: + return None + + if jsonmodel_fullname == "jsonmodels.fields.StringField": + return self._string_field_callback + if jsonmodel_fullname == "jsonmodels.fields.IntField": + return self._int_field_callback + if jsonmodel_fullname == "jsonmodels.fields.FloatField": + return self._float_field_callback + if jsonmodel_fullname == "jsonmodels.fields.BoolField": + return self._bool_field_callback + if jsonmodel_fullname == "jsonmodels.fields.TimeField": + return self._time_field_callback + if jsonmodel_fullname == "jsonmodels.fields.DateField": + return self._date_field_callback + if jsonmodel_fullname == "jsonmodels.fields.DateTimeField": + return self._datetime_field_callback + if jsonmodel_fullname == "jsonmodels.fields.EmbeddedField": + return self._embedded_field_callback + if jsonmodel_fullname == "jsonmodels.fields.ListField": + return self._list_field_callback + if jsonmodel_fullname == "jsonmodels.fields.DerivedListField": + return self._derived_list_field_callback + if jsonmodel_fullname == "jsonmodels.fields.MapField": + return self._map_field_callback + if jsonmodel_fullname == "jsonmodels.fields.GenericField": + return self._generic_field_callback + + return None + + def _wrap_nullable(self, ctx: FunctionContext, core_type: MypyType) -> MypyType: + try: + nullable_index = ctx.callee_arg_names.index("nullable") + except ValueError: + return core_type + + arg_value = ctx.args[nullable_index] + if len(arg_value) == 0: + return core_type + + nullable_value = arg_value[0] + if isinstance(nullable_value, mypy.nodes.NameExpr) and nullable_value.fullname == "builtins.True": + return mypy.types.UnionType([core_type, mypy.types.NoneType()]) + + return core_type + + def _string_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.str")) + + def _int_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.int")) + + def _float_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.float")) + + def _bool_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("builtins.bool")) + + def _time_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.time")) + + def _date_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.date")) + + def _datetime_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, ctx.api.named_type("datetime.datetime")) + + def _get_type_from_arg(self, ctx: FunctionContext, arg_name: str) -> MypyType: + try: + model_types_index = ctx.callee_arg_names.index(arg_name) + except ValueError: + return mypy.types.NoneType() + + arg_value = ctx.args[model_types_index] + if len(arg_value) == 0: + return mypy.types.NoneType() + + model_types_value = arg_value[0] + + if isinstance(model_types_value, mypy.nodes.NameExpr): + return ctx.api.named_type(model_types_value.fullname) + + if isinstance(model_types_value, mypy.nodes.TupleExpr): + accepted_types: List[MypyType] = [] + for item in model_types_value.items: + if isinstance(item, mypy.nodes.NameExpr): + accepted_types.append(ctx.api.named_type(item.fullname)) + return 
mypy.types.UnionType(accepted_types) + + return mypy.types.NoneType() + + def _embedded_field_callback(self, ctx: FunctionContext) -> MypyType: + return self._wrap_nullable(ctx, self._get_type_from_arg(ctx, "model_types")) + + def _list_field_callback(self, ctx: FunctionContext) -> MypyType: + item_type = self._get_type_from_arg(ctx, "items_types") + list_type = ctx.api.named_generic_type("list", [item_type]) + return self._wrap_nullable(ctx, list_type) + + def _get_type_from_arg_type(self, ctx: FunctionContext, arg_name: str) -> MypyType: + try: + model_types_index = ctx.callee_arg_names.index(arg_name) + except ValueError: + return mypy.types.NoneType() + + return ctx.arg_types[model_types_index][0] + + def _derived_list_field_callback(self, ctx: FunctionContext) -> MypyType: + item_type = self._get_type_from_arg_type(ctx, "field") + list_type = ctx.api.named_generic_type("list", [item_type]) + return self._wrap_nullable(ctx, list_type) + + def _map_field_callback(self, ctx: FunctionContext) -> MypyType: + key_type = self._get_type_from_arg_type(ctx, "key_field") + value_type = self._get_type_from_arg_type(ctx, "value_field") + list_type = ctx.api.named_generic_type("dict", [key_type, value_type]) + return self._wrap_nullable(ctx, list_type) + + def _generic_field_callback(self, ctx: FunctionContext) -> MypyType: + return mypy.types.AnyType(mypy.types.TypeOfAny.special_form) + +def plugin(version: str): + return JSONModelsPlugin diff --git a/jsonmodels/parsers.py b/jsonmodels/parsers.py index 3f0f60d..62f8050 100644 --- a/jsonmodels/parsers.py +++ b/jsonmodels/parsers.py @@ -1,16 +1,13 @@ """Parsers to change model structure into different ones.""" import inspect +from typing import Any, cast -from . import fields, builders, errors +from . import builders, errors, fields +from .types import Builder, Field, JSONSchemaProperty, JSONSchemaTypeName, JSONValue, Model -def to_struct(model): - """ - Cast instance of model to python structure. - :param model: Model to be casted. - :rtype: ``dict`` - - """ +def to_struct(model: Model) -> JSONValue: + """Cast instance of model to python structure.""" model.validate() resp = {} @@ -25,18 +22,13 @@ def to_struct(model): return resp -def to_json_schema(cls): - """Generate JSON schema for given class. - - :param cls: Class to be casted. 
- :rtype: ``dict`` - - """ +def to_json_schema(cls: Any) -> JSONSchemaProperty: + """Generate JSON schema for given class.""" builder = build_json_schema(cls) - return builder.build() + return cast(JSONSchemaProperty, builder.build()) -def build_json_schema(value, parent_builder=None): +def build_json_schema(value: Any, parent_builder: Builder | None = None) -> Builder: from .models import Base cls = value if inspect.isclass(value) else value.__class__ @@ -46,7 +38,7 @@ return build_json_schema_primitive(cls, parent_builder) -def build_json_schema_object(cls, parent_builder=None): +def build_json_schema_object(cls: type[Model], parent_builder: Builder | None = None) -> builders.ObjectBuilder: builder = builders.ObjectBuilder(cls, parent_builder) if builder.count_type(builder.type) > 1: return builder @@ -56,12 +48,11 @@ elif isinstance(field, fields.ListField): builder.add_field(name, field, _parse_list(field, builder)) else: - builder.add_field( - name, field, _create_primitive_field_schema(field)) + builder.add_field(name, field, _create_primitive_field_schema(field)) return builder -def _parse_list(field, parent_builder): +def _parse_list(field: fields.ListField, parent_builder: Builder | None) -> str | JSONSchemaProperty: builder = builders.ListBuilder( parent_builder, field.nullable, default=field._default) for type in field.items_types: @@ -69,7 +60,7 @@ return builder.build() -def _parse_embedded(field, parent_builder): +def _parse_embedded(field: fields.EmbeddedField, parent_builder: Builder | None) -> str | JSONSchemaProperty: builder = builders.EmbeddedBuilder( parent_builder, field.nullable, default=field._default) for type in field.types: @@ -77,13 +68,13 @@ return builder.build() -def build_json_schema_primitive(cls, parent_builder): +def build_json_schema_primitive(cls: type, parent_builder: Builder | None) -> Builder: builder = builders.PrimitiveBuilder(cls, parent_builder) return builder -def _create_primitive_field_schema(field): - schema = {'type': _get_schema_type(field)} +def _create_primitive_field_schema(field: Field) -> JSONSchemaProperty: + schema: JSONSchemaProperty = {'type': _get_schema_type(field)} if isinstance(field, fields.FloatField): schema['format'] = 'float' @@ -98,7 +89,8 @@ return schema -def _get_schema_type(field): +def _get_schema_type(field: Field) -> JSONSchemaTypeName: + obj_type: JSONSchemaTypeName if isinstance(field, fields.StringField): obj_type = 'string' elif isinstance(field, fields.IntField): @@ -112,5 +104,5 @@ else: raise errors.FieldNotSupported(type(field)) if field.nullable: - obj_type = [obj_type, 'null'] + return [obj_type, 'null'] return obj_type diff --git a/jsonmodels/types.py b/jsonmodels/types.py new file mode 100644 index 0000000..a225878 --- /dev/null +++ b/jsonmodels/types.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import datetime +from re import Pattern +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Protocol, Tuple, TypedDict, Union, runtime_checkable +from weakref import WeakKeyDictionary + +if TYPE_CHECKING: + from jsonmodels.models import _CacheKey + +Value = Any + +JSONObject = Dict[str, "JSONValue"] +JSONValue = Union[None, bool, str, float, int, List["JSONValue"], JSONObject] + +JSONSchemaBasicTypeName = Literal["string"] | Literal["number"] | Literal["boolean"] |
Literal["object"] | Literal["array"] | Literal["null"] +JSONSchemaTypeName = JSONSchemaBasicTypeName | List[JSONSchemaBasicTypeName | Literal["null"]] + +class JSONSchemaProperty(TypedDict, total=False): + type: JSONSchemaTypeName + format: str + default: Any + required: List[str] + items: JSONSchemaProperty + description: str + properties: Dict[str, str | JSONSchemaProperty] + additionalProperties: bool + definitions: Dict[str, JSONSchemaProperty] + + minItems: int + minLength: int + maxItems: int + maxLength: int + + minimum: int | float + exclusiveMinimum: bool + + maximum: int | float + exclusiveMaximum: bool + + pattern: str + enum: List[str] + +# JSONSchema = JSONSchemaProperty + +# BSON compatible types, which can be returned by toBsonEncodable. +BsonEncodable = Union[ + float, str, object, Dict, List, bytes, bool, datetime.datetime, None, + Pattern, int, bytes +] + +@runtime_checkable +class Builder(Protocol): + def register_type(self, model_type: type[Model], builder: "Builder") -> None: + ... + + def get_builder(self, model_type: type[Model]) -> "Builder": + ... + + def count_type(self, model_type: type[Model]) -> int: + ... + + def build(self) -> str | JSONSchemaProperty: + ... + + def add_definition(self, builder: "Builder") -> None: + ... + + def build_definition(self, add_definitions: bool = True) -> JSONSchemaProperty: + ... + + @property + def is_definition(self) -> bool: + ... + + @property + def type_name(self) -> str: + ... + + +@runtime_checkable +class ValidatorObject(Protocol): + + def validate(self, value: Any) -> None: + ... + + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: + ... + +ValidatorFunction = Callable[[Any], None] +Validator = ValidatorFunction | ValidatorObject + +class Field(Protocol): + types: Tuple[Any, ...] + memory: WeakKeyDictionary + required: bool + validators: List[Validator] + item_validators: List[Validator] + help_text: str | None + nullable: bool + _default: Any + + @property + def has_default(self) -> bool: + ... + + def __set__(self, instance: Model, value: Any) -> None: + ... + + def __get__(self, instance: Model) -> Any: + ... + + def _finish_initialization(self, owner: type[Model]) -> None: + ... + + def to_struct(self, value: Any) -> JSONValue: + ... + + def structure_name(self, default: str) -> str: + ... + + def toBsonEncodable(self, value: Any) -> BsonEncodable: + ... + + def validate_for_object(self, obj: Model) -> None: + ... + + def parse_value(self, value: Value) -> Any: + ... + + def validate(self, value: Value) -> None: + ... + +class CollectionField(Field, Protocol): + def validate_single_value(self, value: Value) -> None: + ... + + +Fields = Tuple[str, Field] +FieldsWithNames = Tuple[str, str, Field] + +class Model(Protocol): + + # __name__: str + _cache_key: _CacheKey + + def validate(self) -> None: + ... + + @classmethod + def iterate_with_name(cls) -> Generator[FieldsWithNames, None, None]: + ... + + def to_struct(self) -> JSONValue: + ... 
+ +EmbedType = Union[type[str], type[int], type[float], type[bool], type[list], type[dict], type[Model]] + +PrimitiveTypeInstance = str | int | float | bool diff --git a/jsonmodels/utilities.py b/jsonmodels/utilities.py index 5e864ed..51c576e 100644 --- a/jsonmodels/utilities.py +++ b/jsonmodels/utilities.py @@ -5,8 +5,12 @@ import six import re from collections import namedtuple +from typing import Dict, cast, Any, List, Tuple -SCALAR_TYPES = tuple(list(six.string_types) + [int, float, bool]) +from jsonmodels.types import JSONSchemaProperty + +six_string_types: List[Any] = list(six.string_types) +SCALAR_TYPES = cast(Tuple[Any], tuple(six_string_types + [int, float, bool])) ECMA_TO_PYTHON_FLAGS = { 'i': re.I, @@ -20,14 +24,14 @@ PythonRegex = namedtuple('PythonRegex', ['regex', 'flags']) -def _normalize_string_type(value): +def _normalize_string_type(value: Any) -> Any: if isinstance(value, six.string_types): return six.text_type(value) else: return value -def _compare_dicts(one, two): +def _compare_dicts(one: Dict, two: Dict) -> bool: if len(one) != len(two): return False @@ -40,7 +44,7 @@ def _compare_dicts(one, two): return True -def _compare_lists(one, two): +def _compare_lists(one: List, two: List) -> bool: if len(one) != len(two): return False @@ -54,13 +58,13 @@ def _compare_lists(one, two): return they_match -def _assert_same_types(one, two): +def _assert_same_types(one: Any, two: Any) -> None: if not isinstance(one, type(two)) or not isinstance(two, type(one)): raise RuntimeError('Types mismatch! "{type1}" and "{type2}".'.format( type1=type(one).__name__, type2=type(two).__name__)) -def compare_schemas(one, two): +def compare_schemas(one: JSONSchemaProperty, two: JSONSchemaProperty) -> bool: """Compare two structures that represents JSON schemas. For comparison you can't use normal comparison, because in JSON schema @@ -83,7 +87,7 @@ def compare_schemas(one, two): if isinstance(one, list): return _compare_lists(one, two) elif isinstance(one, dict): - return _compare_dicts(one, two) + return _compare_dicts(cast(dict, one), cast(dict, two)) elif isinstance(one, SCALAR_TYPES): return one == two elif one is None: @@ -93,7 +97,7 @@ def compare_schemas(one, two): type=type(one).__name__)) -def is_ecma_regex(regex): +def is_ecma_regex(regex: str) -> bool: """Check if given regex is of type ECMA 262 or not. :rtype: bool @@ -110,7 +114,7 @@ def is_ecma_regex(regex): return False -def convert_ecma_regex_to_python(value): +def convert_ecma_regex_to_python(value: str) -> PythonRegex: """Convert ECMA 262 regex to Python tuple with regex and flags. If given value is already Python regex it will be returned unchanged. @@ -134,7 +138,7 @@ def convert_ecma_regex_to_python(value): return PythonRegex('/'.join(parts[1:]), result_flags) -def convert_python_regex_to_ecma(value, flags=()): +def convert_python_regex_to_ecma(value: str, flags: List[re.RegexFlag]=[]) -> str: """Convert Python regex to ECMA 262 regex. If given value is already ECMA regex it will be returned unchanged. 
@@ -149,6 +153,6 @@ def convert_python_regex_to_ecma(value, flags=()): return value result_flags = [PYTHON_TO_ECMA_FLAGS[f] for f in flags] - result_flags = ''.join(result_flags) + result_flags_str = ''.join(result_flags) - return '/{value}/{flags}'.format(value=value, flags=result_flags) + return '/{value}/{flags}'.format(value=value, flags=result_flags_str) diff --git a/jsonmodels/validators.py b/jsonmodels/validators.py index 480141a..3909a6c 100644 --- a/jsonmodels/validators.py +++ b/jsonmodels/validators.py @@ -1,18 +1,19 @@ """Predefined validators.""" import re +from typing import Sized from six.moves import reduce from .errors import MinValidationError, MaxValidationError, BadTypeError, \ RegexError, MinLengthError, MaxLengthError, EnumError from . import utilities - +from .types import JSONSchemaProperty class Min(object): """Validator for minimum value.""" - def __init__(self, minimum_value, exclusive=False): + def __init__(self, minimum_value: int | float, exclusive: bool = False) -> None: """Init. :param minimum_value: Minimum value for validator. @@ -23,13 +24,13 @@ def __init__(self, minimum_value, exclusive=False): self.minimum_value = minimum_value self.exclusive = exclusive - def validate(self, value): + def validate(self, value: int | float) -> None: """Validate value.""" if value < self.minimum_value \ or (self.exclusive and value == self.minimum_value): raise MinValidationError(value, self.minimum_value, self.exclusive) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['minimum'] = self.minimum_value if self.exclusive: @@ -40,7 +41,7 @@ class Max(object): """Validator for maximum value.""" - def __init__(self, maximum_value, exclusive=False): + def __init__(self, maximum_value: int | float, exclusive: bool = False) -> None: """Init. :param maximum_value: Maximum value for validator. @@ -51,13 +52,13 @@ def __init__(self, maximum_value, exclusive=False): self.maximum_value = maximum_value self.exclusive = exclusive - def validate(self, value): + def validate(self, value: int | float) -> None: """Validate value.""" if value > self.maximum_value \ or (self.exclusive and value == self.maximum_value): raise MaxValidationError(value, self.maximum_value, self.exclusive) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['maximum'] = self.maximum_value if self.exclusive: @@ -73,7 +74,7 @@ class Regex(object): 'multiline': re.M, } - def __init__(self, pattern, custom_error=None, **flags): + def __init__(self, pattern: str, custom_error: Exception | None = None, **flags: bool) -> None: """Init.
Note, that if given pattern is ECMA regex, given flags will be @@ -98,7 +99,7 @@ def __init__(self, pattern, custom_error=None, **flags): self.flags = [self.FLAGS[key] for key, value in flags.items() if key in self.FLAGS and value] - def validate(self, value): + def validate(self, value: str) -> None: """Validate value.""" flags = self._calculate_flags() @@ -112,10 +113,10 @@ def validate(self, value): raise self.custom_error raise RegexError(value, self.pattern) - def _calculate_flags(self): + def _calculate_flags(self) -> int: return reduce(lambda x, y: x | y, self.flags, 0) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" field_schema['pattern'] = utilities.convert_python_regex_to_ecma( self.pattern, self.flags ) @@ -126,7 +127,7 @@ class Length(object): """Validator for length.""" - def __init__(self, minimum_value=None, maximum_value=None): + def __init__(self, minimum_value: int | None = None, maximum_value: int | None = None) -> None: """Init. Note that if no `minimum_value` neither `maximum_value` will be @@ -144,7 +145,7 @@ def __init__(self, minimum_value=None, maximum_value=None): self.minimum_value = minimum_value self.maximum_value = maximum_value - def validate(self, value): + def validate(self, value: Sized) -> None: """Validate value.""" len_ = len(value) @@ -154,24 +155,28 @@ def validate(self, value): if self.maximum_value is not None and len_ > self.maximum_value: raise MaxLengthError(value, self.maximum_value) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: """Modify field schema.""" is_array = field_schema.get('type') == 'array' if self.minimum_value: - key = 'minItems' if is_array else 'minLength' - field_schema[key] = self.minimum_value + if is_array: + field_schema['minItems'] = self.minimum_value + else: + field_schema['minLength'] = self.minimum_value if self.maximum_value: - key = 'maxItems' if is_array else 'maxLength' - field_schema[key] = self.maximum_value + if is_array: + field_schema['maxItems'] = self.maximum_value + else: + field_schema['maxLength'] = self.maximum_value class Enum(object): """Validator for enums.""" - def __init__(self, *choices): + def __init__(self, *choices: str) -> None: """Init. :param [] choices: Valid choices for the field.
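A small sketch of how the annotated validators cooperate with the JSONSchemaProperty TypedDict during schema generation; the literal values below are arbitrary and for illustration only.

    from jsonmodels import validators
    from jsonmodels.types import JSONSchemaProperty

    schema: JSONSchemaProperty = {'type': 'string'}
    validators.Length(minimum_value=1, maximum_value=64).modify_schema(schema)
    validators.Regex('^[a-z]+$', ignorecase=True).modify_schema(schema)
    # schema now carries 'minLength', 'maxLength' and 'pattern' entries,
    # all of which are keys declared on JSONSchemaProperty.

    validators.Min(0).validate(5)    # passes silently
    # validators.Min(0).validate(-1) would raise MinValidationError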
@@ -179,9 +184,9 @@ def __init__(self, *choices): self.choices = list(choices) - def validate(self, value): + def validate(self, value: str) -> None: if value not in self.choices: raise EnumError(value, self.choices) - def modify_schema(self, field_schema): + def modify_schema(self, field_schema: JSONSchemaProperty) -> None: field_schema['enum'] = self.choices diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..0fa0de0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +[mypy] +disallow_untyped_defs = True +disallow_any_unimported = True +no_implicit_optional = True +warn_return_any = True +warn_unused_configs = True +warn_unused_ignores = True +show_error_codes = True diff --git a/mypy_plugin.ini b/mypy_plugin.ini new file mode 100644 index 0000000..3c0a5de --- /dev/null +++ b/mypy_plugin.ini @@ -0,0 +1,9 @@ +[mypy] +disallow_untyped_defs = True +disallow_any_unimported = True +no_implicit_optional = True +warn_return_any = True +warn_unused_configs = True +warn_unused_ignores = True +show_error_codes = True +plugins = jsonmodels/mypy_plugin.py diff --git a/requirements.txt b/requirements.txt index 9925baf..c7b171b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ py pyflakes pytest pytest-cov +python-dateutil sphinxcontrib-spelling tox virtualenv diff --git a/run_mypy.sh b/run_mypy.sh new file mode 100755 index 0000000..61b9d3c --- /dev/null +++ b/run_mypy.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +mypy -p jsonmodels diff --git a/tests/__init__.py b/tests/__init__.py index 14df325..398ac48 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,8 @@ def _have_flake8(): try: - import flake8 # noqa: F401 + MYPY = False + if not MYPY: + import flake8 # noqa: F401 return True except ImportError: return False diff --git a/tests/test_fields.py b/tests/test_fields.py index dc740d0..438b31d 100755 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,8 +1,9 @@ from collections import OrderedDict import datetime from datetime import timezone +from typing import Dict, Union -import pytest +import pytest # type: ignore from jsonmodels import models, fields, validators, errors @@ -14,7 +15,7 @@ def test_deprecated_structue_name(): assert field.structue_name('default') == 'default' -def test_bool_field(): +def test_bool_field() -> None: field = fields.BoolField() @@ -43,7 +44,7 @@ class Person(models.Base): assert field.parse_value([]) is False -def test_datetime_field(): +def test_datetime_field() -> None: field = fields.DateTimeField() class Event(models.Base): @@ -59,7 +60,7 @@ class Event(models.Base): datetime.datetime(2019, 10, 30, 1, 2, 3, tzinfo=timezone.utc) -def test_custom_field(): +def test_custom_field() -> None: class NameField(fields.StringField): def __init__(self): super(NameField, self).__init__(required=True) @@ -75,7 +76,7 @@ class Person(models.Base): assert person.to_struct() == expected -def test_custom_field_validation(): +def test_custom_field_validation() -> None: class NameField(fields.StringField): def __init__(self): super(NameField, self).__init__( @@ -102,7 +103,7 @@ class Person(models.Base): person.validate() -def test_map_field(): +def test_map_field() -> None: class Model(models.Base): str_to_int = fields.MapField(fields.StringField(), fields.IntField()) int_to_str = fields.MapField(fields.IntField(), fields.StringField()) @@ -131,13 +132,13 @@ class CircularMapModel(models.Base): ) -def test_map_field_circular(): +def test_map_field_circular() -> None: model = CircularMapModel(mapping={1: {}, 2: CircularMapModel()}) - 
expected = {'mapping': {1: {}, 2: {}}} + expected: Dict[str, Dict[int, Dict]] = {'mapping': {1: {}, 2: {}}} assert expected == model.to_struct() -def test_map_field_validation(): +def test_map_field_validation() -> None: class Model(models.Base): str_to_int = fields.MapField(fields.StringField(), fields.IntField()) int_to_str = fields.MapField(fields.IntField(), fields.StringField(), @@ -162,7 +163,7 @@ class Model(models.Base): model.validate() -def test_generic_field(): +def test_generic_field() -> None: class Model(models.Base): field = fields.GenericField() @@ -178,7 +179,7 @@ class Model(models.Base): assert expected == model_ordered.to_struct() -def test_derived_list_omit_empty(): +def test_derived_list_omit_empty() -> None: class Car(models.Base): wheels = fields.DerivedListField(fields.StringField(), @@ -190,7 +191,7 @@ class Car(models.Base): assert viper.to_struct() == {"doors": []} -def test_automatic_model_detection(): +def test_automatic_model_detection() -> None: class FullName(models.Base): first_name = fields.StringField() diff --git a/tests/test_jsonmodels.py b/tests/test_jsonmodels.py index 3e0e162..db07251 100644 --- a/tests/test_jsonmodels.py +++ b/tests/test_jsonmodels.py @@ -54,10 +54,10 @@ class Person(models.Base): alan = Person() - with pytest.raises(ValueError): + with pytest.raises(errors.FieldValidationError): alan.name = 'some name' - with pytest.raises(ValueError): + with pytest.raises(errors.FieldValidationError): alan.name = 2345 diff --git a/tests_mypy/__init__.py b/tests_mypy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests_mypy/case_alias.py b/tests_mypy/case_alias.py new file mode 100644 index 0000000..f6a6cf6 --- /dev/null +++ b/tests_mypy/case_alias.py @@ -0,0 +1,11 @@ +from jsonmodels import models, fields + +first_name_field = fields.StringField + +class AliasModel(models.Base): + name = first_name_field() + +alias = AliasModel() + +reveal_type(alias.name) +# expect: builtins.str diff --git a/tests_mypy/case_date.py b/tests_mypy/case_date.py new file mode 100644 index 0000000..aab2422 --- /dev/null +++ b/tests_mypy/case_date.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.dob) +# expect: datetime.date diff --git a/tests_mypy/case_datetime.py b/tests_mypy/case_datetime.py new file mode 100644 index 0000000..27edaa2 --- /dev/null +++ b/tests_mypy/case_datetime.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.last_update) +# expect: datetime.datetime diff --git a/tests_mypy/case_derivedlist.py b/tests_mypy/case_derivedlist.py new file mode 100644 index 0000000..fefda8d --- /dev/null +++ b/tests_mypy/case_derivedlist.py @@ -0,0 +1,7 @@ +from models import person + +reveal_type(person.nicknames) +# expect: builtins.list[builtins.str] + +reveal_type(person.alias_names) +# expect: builtins.list[builtins.str] diff --git a/tests_mypy/case_embedded.py b/tests_mypy/case_embedded.py new file mode 100644 index 0000000..26dc711 --- /dev/null +++ b/tests_mypy/case_embedded.py @@ -0,0 +1,7 @@ +from models import person + +reveal_type(person.address) +# expect: models.Address + +reveal_type(person.transport) +# expect: Union[models.Car, models.Boat] diff --git a/tests_mypy/case_generic.py b/tests_mypy/case_generic.py new file mode 100644 index 0000000..46c1996 --- /dev/null +++ b/tests_mypy/case_generic.py @@ -0,0 +1,4 @@ +from models import car_registry + +reveal_type(car_registry.any_random) +# expect: Any diff --git a/tests_mypy/case_int.py b/tests_mypy/case_int.py new file mode 100644 index 
0000000..4e8b370 --- /dev/null +++ b/tests_mypy/case_int.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.age) +# expect: builtins.int diff --git a/tests_mypy/case_list.py b/tests_mypy/case_list.py new file mode 100644 index 0000000..87bf2d4 --- /dev/null +++ b/tests_mypy/case_list.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.pet_names) +# expect: builtins.list[builtins.str] diff --git a/tests_mypy/case_map.py b/tests_mypy/case_map.py new file mode 100644 index 0000000..dead5dc --- /dev/null +++ b/tests_mypy/case_map.py @@ -0,0 +1,4 @@ +from models import car_registry + +reveal_type(car_registry.registry) +# expect: builtins.dict[builtins.str, builtins.str] diff --git a/tests_mypy/case_nullable.py b/tests_mypy/case_nullable.py new file mode 100644 index 0000000..0af3d9e --- /dev/null +++ b/tests_mypy/case_nullable.py @@ -0,0 +1,7 @@ +from models import address + +reveal_type(address.line_1) +# expect: builtins.str + +reveal_type(address.line_2) +# expect: Union[builtins.str, None] diff --git a/tests_mypy/case_str.py b/tests_mypy/case_str.py new file mode 100644 index 0000000..fbc146f --- /dev/null +++ b/tests_mypy/case_str.py @@ -0,0 +1,4 @@ +from models import person + +reveal_type(person.name) +# expect: builtins.str diff --git a/tests_mypy/case_subfield.py b/tests_mypy/case_subfield.py new file mode 100644 index 0000000..694d95f --- /dev/null +++ b/tests_mypy/case_subfield.py @@ -0,0 +1,13 @@ +from jsonmodels import models, fields + + +class SubStringField(fields.StringField): + pass + +class TestModel(models.Base): + name = SubStringField() + +test_instance = TestModel() + +reveal_type(test_instance.name) +# expect: builtins.str diff --git a/tests_mypy/models.py b/tests_mypy/models.py new file mode 100644 index 0000000..8bc2e9b --- /dev/null +++ b/tests_mypy/models.py @@ -0,0 +1,34 @@ +from jsonmodels import models, fields + +class Address(models.Base): + line_1 = fields.StringField() + line_2 = fields.StringField(nullable=True) + city = fields.StringField() + +class Car(models.Base): + registration = fields.StringField() + +class Boat(models.Base): + name = fields.StringField() + +class Person(models.Base): + name = fields.StringField() + surname = fields.StringField() + age = fields.IntField() + dob = fields.DateField() + alive = fields.BoolField() + last_update = fields.DateTimeField() + address = fields.EmbeddedField(model_types=Address) + transport = fields.EmbeddedField(model_types=(Car, Boat)) + pet_names = fields.ListField(items_types=str) + nicknames = fields.DerivedListField(field=fields.StringField()) + alias_names = fields.DerivedListField(fields.StringField()) + +class CarRegistry(models.Base): + registry = fields.MapField(key_field=fields.StringField(), value_field=fields.StringField()) + any_random = fields.GenericField() + + +person = Person() +address = Address() +car_registry = CarRegistry() diff --git a/tests_mypy/test_mypy_plugin.py b/tests_mypy/test_mypy_plugin.py new file mode 100644 index 0000000..6238b43 --- /dev/null +++ b/tests_mypy/test_mypy_plugin.py @@ -0,0 +1,50 @@ +from mypy import api +import os + + +EXPECT_LINE = "# expect: " +EXPECT_LINE_OUTPUT = "Revealed type is " + + +def test_file(directory: str, file_name: str) -> bool: + expected: list[str] = [] + file_path = os.path.join(directory, file_name) + with open(file_path, 'r') as f: + lines = f.readlines() + for line in lines: + if line.startswith(EXPECT_LINE): + expected.append(line[len(EXPECT_LINE):].strip()) + + result = api.run([ + 
"--config-file=../mypy_plugin.ini", + "--show-traceback", + file_path]) + + output_expected: list[str] = [] + for output_line in result[0].splitlines(): + index = output_line.find(EXPECT_LINE_OUTPUT) + if index > 0: + output_expected.append(output_line[index + len(EXPECT_LINE_OUTPUT):].strip().strip('"')) + + if expected == output_expected: + print(f"PASS {file_name}") + return True + else: + print(f"FAIL {file_name}\n") + print(f"Expected: {repr(expected)}") + print(f"Received: {repr(output_expected)}") + print("STDOUT----------------") + print(result[0]) + print(result[1]) + print("----------------------") + return False + +def main() -> None: + directory = '.' + files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and f.startswith("case_")] + files.sort() + + for file_name in files: + test_file(directory, file_name) + +main()