Skip to content

Commit

Permalink
Extract schema transformation in module
Browse files Browse the repository at this point in the history
  • Loading branch information
mdellweg committed Nov 9, 2024
1 parent 51e18be commit 65bd8bf
Show file tree
Hide file tree
Showing 4 changed files with 622 additions and 198 deletions.
202 changes: 5 additions & 197 deletions pulp-glue/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
# copyright (c) 2020, Matthias Dellweg
# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)

import base64
import datetime
import json
import os
import typing as t
from collections import defaultdict
from contextlib import suppress
from io import BufferedReader
from urllib.parse import urljoin

Expand All @@ -16,15 +13,14 @@

from pulp_glue.common import __version__
from pulp_glue.common.i18n import get_translation
from pulp_glue.common.schema import ValidationError, transform

translation = get_translation(__package__)
_ = translation.gettext

UploadType = t.Union[bytes, t.IO[bytes]]

SAFE_METHODS = ["GET", "HEAD", "OPTIONS"]
ISO_DATE_FORMAT = "%Y-%m-%d"
ISO_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


class OpenAPIError(Exception):
Expand Down Expand Up @@ -363,198 +359,10 @@ def extract_params(
return result

def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
# Look if the schema is provided by reference
schema_ref = schema.get("$ref")
if schema_ref:
if not schema_ref.startswith("#/components/schemas/"):
raise OpenAPIError(_("Api spec is invalid."))
# len("#/components/schemas/") == 21
schema_name = schema_ref[21:]
schema = self.api_spec["components"]["schemas"][schema_name]

if value is None and schema.get("nullable", False):
return None

schema_type = schema.get("type")
allOf = schema.get("allOf")
anyOf = schema.get("anyOf")
oneOf = schema.get("oneOf")
not_schema = schema.get("not")
if allOf:
old_value = value
value = self.validate_schema(allOf[0], name, value)
for sub_schema in allOf[1:]:
# TODO check if it is possible to combine non object types.
value.update(self.validate_schema(sub_schema, name, old_value))
elif anyOf:
for sub_schema in anyOf:
with suppress(OpenAPIValidationError):
value = self.validate_schema(sub_schema, name, value)
break
else:
raise OpenAPIValidationError(
_("No schema in anyOf validated for {name}.").format(name=name)
)
elif oneOf:
old_value = value
found_valid = False
for sub_schema in anyOf:
with suppress(OpenAPIValidationError):
value = self.validate_schema(sub_schema, name, old_value)
if found_valid:
raise OpenAPIValidationError(
_("Multiple schemas in oneOf validated for {name}.").format(name=name)
)
found_valid = True
if not found_valid:
raise OpenAPIValidationError(
_("No schema in oneOf validated for {name}.").format(name=name)
)
elif not_schema:
try:
self.validate_schema(not_schema, name, value)
except OpenAPIValidationError:
pass
else:
raise OpenAPIValidationError(
_("Forbidden schema for {name} validated.").format(name=name)
)
elif schema_type is None:
# Schema type is not specified.
# JSONField
pass
elif schema_type == "object":
# Serializer
value = self.validate_object(schema, name, value)
elif schema_type == "array":
# ListField
value = self.validate_array(schema, name, value)
elif schema_type == "string":
# CharField
# TextField
# DateTimeField etc.
# ChoiceField
# FileField (binary data)
value = self.validate_string(schema, name, value)
elif schema_type == "integer":
# IntegerField
value = self.validate_integer(schema, name, value)
elif schema_type == "number":
# FloatField
value = self.validate_number(schema, name, value)
elif schema_type == "boolean":
# BooleanField
if not isinstance(value, bool):
raise OpenAPIValidationError(
_("'{name}' is expected to be a boolean.").format(name=name)
)
# TODO: Add more types here.
else:
raise OpenAPIError(
_("Type `{schema_type}` is not implemented yet.").format(schema_type=schema_type)
)
return value

def validate_object(self, schema: t.Any, name: str, value: t.Any) -> t.Dict[str, t.Any]:
if not isinstance(value, t.Dict):
raise OpenAPIValidationError(
_("'{name}' is expected to be an object.").format(name=name)
)
properties = schema.get("properties", {})
additional_properties = schema.get("additionalProperties")
if properties or additional_properties is not None:
value = value.copy()
for property_name, property_value in value.items():
property_schema = properties.get(property_name, additional_properties)
if not property_schema:
raise OpenAPIValidationError(
_("Unexpected property '{property_name}' for '{name}' provided.").format(
name=name, property_name=property_name
)
)
value[property_name] = self.validate_schema(
property_schema, property_name, property_value
)
if "required" in schema:
missing_properties = set(schema["required"]) - set(value.keys())
if missing_properties:
raise OpenAPIValidationError(
_("Required properties(s) '{missing_properties}' of '{name}' missing.").format(
name=name, missing_properties=missing_properties
)
)
return value

def validate_array(self, schema: t.Any, name: str, value: t.Any) -> t.List[t.Any]:
if not isinstance(value, t.List):
raise OpenAPIValidationError(_("'{name}' is expected to be a list.").format(name=name))
item_schema = schema["items"]
return [self.validate_schema(item_schema, name, item) for item in value]

def validate_string(self, schema: t.Any, name: str, value: t.Any) -> t.Union[str, UploadType]:
enum = schema.get("enum")
if enum:
if value not in enum:
raise OpenAPIValidationError(
_("'{name}' is not one of the valid choices.").format(name=name)
)
schema_format = schema.get("format")
if schema_format == "date":
if not isinstance(value, datetime.date):
raise OpenAPIValidationError(
_("'{name}' is expected to be a date.").format(name=name)
)
return value.strftime(ISO_DATE_FORMAT)
elif schema_format == "date-time":
if not isinstance(value, datetime.datetime):
raise OpenAPIValidationError(
_("'{name}' is expected to be a datetime.").format(name=name)
)
return value.strftime(ISO_DATETIME_FORMAT)
elif schema_format == "bytes":
if not isinstance(value, bytes):
raise OpenAPIValidationError(
_("'{name}' is expected to be bytes.").format(name=name)
)
return base64.b64encode(value)
elif schema_format == "binary":
if not isinstance(value, (bytes, BufferedReader)):
raise OpenAPIValidationError(
_("'{name}' is expected to be binary.").format(name=name)
)
return value
else:
if not isinstance(value, str):
raise OpenAPIValidationError(
_("'{name}' is expected to be a string.").format(name=name)
)
return value

def validate_integer(self, schema: t.Any, name: str, value: t.Any) -> int:
if not isinstance(value, int):
raise OpenAPIValidationError(
_("'{name}' is expected to be an integer.").format(name=name)
)
minimum = schema.get("minimum")
if minimum is not None and value < minimum:
raise OpenAPIValidationError(
_("'{name}' is violating the minimum constraint").format(name=name)
)
maximum = schema.get("maximum")
if maximum is not None and value > maximum:
raise OpenAPIValidationError(
_("'{name}' is violating the maximum constraint").format(name=name)
)
return value

def validate_number(self, schema: t.Any, name: str, value: t.Any) -> float:
# https://swagger.io/specification/#data-types describes float and double.
# Python does not distinguish them.
if not isinstance(value, float):
raise OpenAPIValidationError(
_("'{name}' is expected to be a number.").format(name=name)
)
return value
try:
return transform(schema, name, value, self.api_spec["components"]["schemas"])
except ValidationError as e:
raise OpenAPIValidationError(str(e)) from e

def render_request_body(
self,
Expand Down
Loading

0 comments on commit 65bd8bf

Please sign in to comment.