Skip to content

Commit

Permalink
Merge pull request #23 from PrefectHQ/reconcile-with-pydantic-v2
Browse files Browse the repository at this point in the history
Updating the tests to be tolerant of the differences between Pydantic v1 and v2
  • Loading branch information
chrisguidry authored May 29, 2024
2 parents 8272664 + 51bfcad commit edc97e0
Showing 1 changed file with 120 additions and 38 deletions.
158 changes: 120 additions & 38 deletions test_oss_cloud_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import re
from typing import Any

import pytest

Expand Down Expand Up @@ -92,6 +93,20 @@ def convert_oss_endpoint_to_cloud(endpoint):
return endpoint


def lookup_content_body_schema(body: dict[str, Any]) -> dict[str, Any] | None:
"""Given the schema for an endpoint, find the JSON response's content schema"""
schema = body.get("content", {}).get("application/json", {}).get("schema", {})

# In pydantic v1, the schema reference is a single value
if "$ref" in schema:
return schema.get("$ref")
# In pydantic v2, the schema reference is an `allOf` with a single item
elif "allOf" in schema:
return schema.get("allOf", [{}])[0].get("$ref")

return None


OSS_PATHS = generate_oss_paths_by_method()
OSS_TYPES = generate_oss_types()

Expand Down Expand Up @@ -133,9 +148,15 @@ def test_api_path_parameters_are_compatible(oss_path, cloud_paths):

def param_type_and_format(schema):
if "anyOf" in schema:
return [(item["type"], item.get("format")) for item in schema["anyOf"]]
# Pydantic v2 renders optional fields with `anyOf` (type, null), but
# Pydantic v1 does not, so let's strip all the additional `null` types out
return [
(item["type"], item.get("format"))
for item in schema["anyOf"]
if item.get("type") != "null"
]
else:
return (schema.get("type"), schema.get("format"))
return [(schema.get("type"), schema.get("format"))]

# check schemas
cloud_params = {
Expand Down Expand Up @@ -190,18 +211,8 @@ def test_api_request_bodies_are_compatible(oss_path, oss_schema, cloud_schema):
cloud_body = cloud_paths[cloud_endpoint][method].get("requestBody", {})
oss_body = path[method].get("requestBody", {})

cloud_body_schema = (
cloud_body.get("content", {})
.get("application/json", {})
.get("schema", {})
.get("$ref")
)
oss_body_schema = (
oss_body.get("content", {})
.get("application/json", {})
.get("schema", {})
.get("$ref")
)
cloud_body_schema = lookup_content_body_schema(cloud_body)
oss_body_schema = lookup_content_body_schema(oss_body)

cloud_ref_schema = lookup_schema_ref(
schema=cloud_schema, ref=cloud_body_schema
Expand Down Expand Up @@ -239,7 +250,10 @@ def extract_types(d):

cloud_props = (
cloud_ref_schema["type"],
{name: prop_gettr(name, d) for name, d in cloud_ref_schema["properties"].items()},
{
name: prop_gettr(name, d)
for name, d in cloud_ref_schema["properties"].items()
},
)
oss_props = (
oss_ref_schema["type"],
Expand All @@ -255,11 +269,58 @@ def extract_types(d):
# - new Cloud fields aren't required (this is difficult to check right now as it's method dependent!)
assert cloud_props[0] == oss_props[0]

# Handling of aliases is different between Pydantic v2 and v1, so we'll force
# some name overrides here
KNOWN_ALIASES = {
"/api/flow_runs/history": {
"post": {
"history_interval": "history_interval_seconds",
}
},
"/api/task_runs/history": {
"post": {
"history_interval": "history_interval_seconds",
}
},
"/api/ui/schemas/validate": {
"post": {
"json_schema": "schema",
}
},
}

# ensure every OSS field is present in Cloud
# ensure the property attributes are the same or a subset (like in the case of type)
for field_name, (oss_name, oss_types, oss_format, oss_default, oss_deprecated) in oss_props[1].items():
assert field_name in cloud_props[1]
(cloud_name, cloud_types, cloud_format, cloud_default, cloud_deprecated) = cloud_props[1][field_name]
for (
oss_name,
oss_types,
oss_format,
oss_default,
oss_deprecated,
) in oss_props[1].values():
if endpoint in KNOWN_ALIASES:
if method in KNOWN_ALIASES[endpoint]:
if oss_name in KNOWN_ALIASES[endpoint][method]:
oss_name = KNOWN_ALIASES[endpoint][method][oss_name]

assert oss_name in cloud_props[1]
(
cloud_name,
cloud_types,
cloud_format,
cloud_default,
cloud_deprecated,
) = cloud_props[1][oss_name]

# In Pydantic v2, if a field is not required, its format is not included, so
# we need to remove it from the comparison
if "null" in oss_types and oss_format is None and cloud_format is not None:
cloud_format = None

# While OSS and Cloud are on different versions of pydantic, there is a
# discrepancy where any optional OSS type (correctly) includes `anyOf` `null`
# while Cloud does not.
oss_types.discard("null")

assert oss_name == cloud_name
assert oss_types <= cloud_types
Expand All @@ -268,19 +329,41 @@ def extract_types(d):
assert oss_deprecated == cloud_deprecated


@pytest.mark.parametrize("oss_type", OSS_TYPES, ids=[name for (name, _) in OSS_TYPES])
def test_oss_api_types_are_cloud_compatible(oss_type, cloud_schema):
@pytest.mark.parametrize(
"oss_name_and_type", OSS_TYPES, ids=[name for (name, _) in OSS_TYPES]
)
def test_oss_api_types_are_cloud_compatible(oss_name_and_type, cloud_schema):
cloud_types = cloud_schema["components"]["schemas"]
name, typ = oss_type
name, oss_type = oss_name_and_type

# ignore missing for now, as there are name incompatibilities to study
if name not in cloud_types:
try:
cloud_type = cloud_types[name]
except KeyError:
return

# preprocess pydantic v1 schema to match pydantic v2 schema
def preprocess_pydantic_v1_type(schema):
# Rewrites `schema` in place so that its non-required properties carry an
# explicit `null` alternative, then returns the same (mutated) `schema` object.
# transform any non-required fields to be anyOf (null, type)
for field_name, props in schema.get("properties", {}).items():
# NOTE: when `schema` has no "required" key, this is a fresh default list, so
# the `append` further down is not stored back on `schema`.
required_fields = schema.get("required", [])
if field_name not in required_fields:
current_definition = schema["properties"][field_name]
if "anyOf" in current_definition:
# already a union: add `null` as one more alternative
schema["properties"][field_name]["anyOf"].append({"type": "null"})
else:
# wrap the original definition in a (null, type) union
schema["properties"][field_name] = {
"anyOf": [{"type": "null"}, props]
}
# NOTE(review): indentation is lost in this view; presumably this runs only for
# fields that were just rewritten above — confirm against the real file.
required_fields.append(field_name)
return schema

cloud_type = preprocess_pydantic_v1_type(cloud_type)

for master_key in ["properties", "required", "enum", "type"]:
oss_props, cloud_props = (
typ.get(master_key, {}),
cloud_types[name].get(master_key, {}),
oss_type.get(master_key, {}),
cloud_type.get(master_key, {}),
)

if not isinstance(oss_props, dict):
Expand Down Expand Up @@ -310,23 +393,22 @@ def test_oss_api_types_are_cloud_compatible(oss_type, cloud_schema):
if props.get("type"):
oss_options = {props.get("type")}
elif props.get("anyOf"):
oss_options = {opt.get("type") for opt in props.get("anyOf") if opt.get("type")}
oss_options = {
opt.get("type") for opt in props.get("anyOf") if opt.get("type")
}

if cloud_props[field_name].get("type"):
cloud_options = {cloud_props[field_name].get("type")}
elif cloud_props[field_name].get("anyOf"):
cloud_options = {opt.get("type") for opt in cloud_props[field_name].get("anyOf") if opt.get("type")}

assert oss_options <= cloud_options
cloud_options = {
opt.get("type")
for opt in cloud_props[field_name].get("anyOf")
if opt.get("type")
}

# While OSS and Cloud are on different versions of pydantic, there is a
# discrepancy where any optional OSS type (correctly) includes `anyOf` `null`
# while Cloud does not.
oss_options.discard("null")

def deep_tuple(o):
"""Given an object `o` that is either a dictionary, a list, or any other hashable
type, return a tuple representation of it where all components are also converted
into tuples."""
if isinstance(o, dict):
return tuple((k, deep_tuple(v)) for k, v in o.items())
elif isinstance(o, list):
return tuple(deep_tuple(v) for v in o)
else:
return o
assert oss_options <= cloud_options

0 comments on commit edc97e0

Please sign in to comment.