Skip to content

Commit

Permalink
Use result instead of exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
expobrain committed Oct 13, 2024
1 parent 1a403d6 commit 8f25263
Show file tree
Hide file tree
Showing 13 changed files with 429 additions and 282 deletions.
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pyyaml = ">=5.4"
loguru = ">=0.7"
typing-extensions = ">=4.6"
greenlet = ">=3"
result = "^0.17.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11"
Expand Down
17 changes: 13 additions & 4 deletions sqlalchemy_to_json_schema/command/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Optional, Union, cast

import yaml
from result import Err, Ok, Result
from sqlalchemy.ext.declarative import DeclarativeMeta

from sqlalchemy_to_json_schema.command.transformer import (
Expand Down Expand Up @@ -57,7 +58,9 @@ def __init__(self, walker: Walker, decision: Decision, layout: Layout, /):

def build_transformer(
self, walker: Walker, decision: Decision, layout: Layout, /
) -> Callable[[Iterable[Union[ModuleType, DeclarativeMeta]], Optional[int]], Schema]:
) -> Callable[
[Iterable[Union[ModuleType, DeclarativeMeta]], Optional[int]], Result[Schema, str]
]:
walker_factory = WALKER_MAP[walker]
relation_decision = DECISION_MAP[decision]()
schema_factory = SchemaFactory(walker_factory, relation_decision=relation_decision)
Expand All @@ -73,7 +76,7 @@ def run(
filename: Optional[Path] = None,
format: Optional[Format] = None,
depth: Optional[int] = None,
) -> None:
) -> Result[None, str]:
modules_and_types = (load_module_or_symbol(target) for target in targets)
modules_and_models = cast(
Iterator[Union[ModuleType, DeclarativeMeta]],
Expand All @@ -84,8 +87,14 @@ def run(
),
)

result = self.transformer(modules_and_models, depth)
self.dump(result, filename=filename, format=format)
schema = self.transformer(modules_and_models, depth)

if schema.is_err():
return Err(schema.unwrap_err())

self.dump(schema.unwrap(), filename=filename, format=format)

return Ok(None)

def dump(
self,
Expand Down
81 changes: 59 additions & 22 deletions sqlalchemy_to_json_schema/command/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Union

from loguru import logger
from result import Err, Ok, Result
from sqlalchemy.ext.declarative import DeclarativeMeta
from typing_extensions import TypeGuard

Expand All @@ -18,13 +19,13 @@ def __init__(self, schema_factory: SchemaFactory, /):
@abstractmethod
def transform(
self, rawtargets: Iterable[Union[ModuleType, DeclarativeMeta]], depth: Optional[int], /
) -> Schema: ...
) -> Result[Schema, str]: ...


class JSONSchemaTransformer(AbstractTransformer):
def transform(
self, rawtargets: Iterable[Union[ModuleType, DeclarativeMeta]], depth: Optional[int], /
) -> Schema:
) -> Result[Schema, str]:
definitions = {}

for item in rawtargets:
Expand All @@ -33,33 +34,46 @@ def transform(
elif inspect.ismodule(item):
partial_definitions = self.transform_by_module(item, depth)
else:
TypeError(f"Expected a class or module, got {item}")
return Err(f"Expected a class or module, got {item}")

definitions.update(partial_definitions)
if partial_definitions.is_err():
return partial_definitions

return definitions
definitions.update(partial_definitions.unwrap())

def transform_by_model(self, model: DeclarativeMeta, depth: Optional[int], /) -> Schema:
return Ok(definitions)

def transform_by_model(
self, model: DeclarativeMeta, depth: Optional[int], /
) -> Result[Schema, str]:
return self.schema_factory(model, depth=depth)

def transform_by_module(self, module: ModuleType, depth: Optional[int], /) -> Schema:
def transform_by_module(
self, module: ModuleType, depth: Optional[int], /
) -> Result[Schema, str]:
subdefinitions = {}
definitions = {}
for basemodel in collect_models(module):
schema = self.schema_factory(basemodel, depth=depth)
schema_result = self.schema_factory(basemodel, depth=depth)

if schema_result.is_err():
return schema_result

schema = schema_result.unwrap()

if "definitions" in schema:
subdefinitions.update(schema.pop("definitions"))
definitions[schema["title"]] = schema
d = {}
d.update(subdefinitions)
d.update(definitions)
return {"definitions": definitions}
return Ok({"definitions": definitions})


class OpenAPI2Transformer(AbstractTransformer):
def transform(
self, rawtargets: Iterable[Union[ModuleType, DeclarativeMeta]], depth: Optional[int], /
) -> Schema:
) -> Result[Schema, str]:
definitions = {}

for target in rawtargets:
Expand All @@ -68,29 +82,46 @@ def transform(
elif inspect.ismodule(target):
partial_definitions = self.transform_by_module(target, depth)
else:
raise TypeError(f"Expected a class or module, got {target}")
return Err(f"Expected a class or module, got {target}")

if partial_definitions.is_err():
return partial_definitions

definitions.update(partial_definitions)
definitions.update(partial_definitions.unwrap())

return {"definitions": definitions}
return Ok({"definitions": definitions})

def transform_by_model(self, model: DeclarativeMeta, depth: Optional[int], /) -> Schema:
def transform_by_model(
self, model: DeclarativeMeta, depth: Optional[int], /
) -> Result[Schema, str]:
definitions = {}
schema = self.schema_factory(model, depth=depth)
schema_result = self.schema_factory(model, depth=depth)

if schema_result.is_err():
return schema_result

schema = schema_result.unwrap()

if "definitions" in schema:
definitions.update(schema.pop("definitions"))

definitions[schema["title"]] = schema

return definitions
return Ok(definitions)

def transform_by_module(self, module: ModuleType, depth: Optional[int], /) -> Schema:
def transform_by_module(
self, module: ModuleType, depth: Optional[int], /
) -> Result[Schema, str]:
subdefinitions = {}
definitions = {}

for basemodel in collect_models(module):
schema = self.schema_factory(basemodel, depth=depth)
schema_result = self.schema_factory(basemodel, depth=depth)

if schema_result.is_err():
return schema_result

schema = schema_result.unwrap()

if "definitions" in schema:
subdefinitions.update(schema.pop("definitions"))
Expand All @@ -101,7 +132,7 @@ def transform_by_module(self, module: ModuleType, depth: Optional[int], /) -> Sc
d.update(subdefinitions)
d.update(definitions)

return definitions
return Ok(definitions)


class OpenAPI3Transformer(OpenAPI2Transformer):
Expand All @@ -118,8 +149,13 @@ def replace_ref(self, d: Union[dict, list], old_prefix: str, new_prefix: str, /)

def transform(
self, rawtargets: Iterable[Union[ModuleType, DeclarativeMeta]], depth: Optional[int], /
) -> Schema:
definitions = super().transform(rawtargets, depth)
) -> Result[Schema, str]:
definitions_result = super().transform(rawtargets, depth)

if definitions_result.is_err():
return Err(definitions_result.unwrap_err())

definitions = definitions_result.unwrap()

self.replace_ref(definitions, "#/definitions/", "#/components/schemas/")

Expand All @@ -128,7 +164,8 @@ def transform(
if "schemas" not in definitions["components"]:
definitions["components"]["schemas"] = {}
definitions["components"]["schemas"] = definitions.pop("definitions", {})
return definitions

return Ok(definitions)


def collect_models(module: ModuleType, /) -> Iterator[DeclarativeMeta]:
Expand Down
49 changes: 30 additions & 19 deletions sqlalchemy_to_json_schema/decisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Iterator
from typing import Any, Union

from result import Err, Ok, Result
from sqlalchemy.orm import MapperProperty
from sqlalchemy.orm.base import MANYTOMANY, MANYTOONE
from sqlalchemy.orm.properties import ColumnProperty
Expand All @@ -24,7 +25,7 @@ def decision(
/,
*,
toplevel: bool = False,
) -> Iterator[DecisionResult]:
) -> Iterator[Result[DecisionResult, MapperProperty]]:
pass


Expand All @@ -36,13 +37,13 @@ def decision(
/,
*,
toplevel: bool = False,
) -> Iterator[DecisionResult]:
) -> Iterator[Result[DecisionResult, MapperProperty]]:
if hasattr(prop, "mapper"):
yield ColumnPropertyType.RELATIONSHIP, prop, {}
yield Ok((ColumnPropertyType.RELATIONSHIP, prop, {}))
elif hasattr(prop, "columns"):
yield ColumnPropertyType.FOREIGNKEY, prop, {}
yield Ok((ColumnPropertyType.FOREIGNKEY, prop, {}))
else:
raise NotImplementedError(prop)
yield Err(prop)


class UseForeignKeyIfPossibleDecision(AbstractDecision):
Expand All @@ -53,32 +54,42 @@ def decision(
/,
*,
toplevel: bool = False,
) -> Iterator[DecisionResult]:
) -> Iterator[Result[DecisionResult, MapperProperty]]:
if hasattr(prop, "mapper"):
if prop.direction == MANYTOONE:
if toplevel:
for c in prop.local_columns:
yield ColumnPropertyType.FOREIGNKEY, walker.mapper._props[c.name], {
"relation": prop.key
}
yield Ok(
(
ColumnPropertyType.FOREIGNKEY,
walker.mapper._props[c.name],
{"relation": prop.key},
)
)
else:
rp = walker.history[0]
if prop.local_columns != rp.remote_side:
for c in prop.local_columns:
yield ColumnPropertyType.FOREIGNKEY, walker.mapper._props[c.name], {
"relation": prop.key
}
yield Ok(
(
ColumnPropertyType.FOREIGNKEY,
walker.mapper._props[c.name],
{"relation": prop.key},
)
)
elif prop.direction == MANYTOMANY:
# logger.warning("skip mapper=%s, prop=%s is many to many.", walker.mapper, prop)
# fixme: this must return a ColumnPropertyType member
yield (
{"type": "array", "items": {"type": "string"}}, # type: ignore[misc]
prop,
{},
yield Ok(
( # type: ignore[arg-type]
{"type": "array", "items": {"type": "string"}},
prop,
{},
)
)
else:
yield ColumnPropertyType.RELATIONSHIP, prop, {}
yield Ok((ColumnPropertyType.RELATIONSHIP, prop, {}))
elif hasattr(prop, "columns"):
yield ColumnPropertyType.FOREIGNKEY, prop, {}
yield Ok((ColumnPropertyType.FOREIGNKEY, prop, {}))
else:
raise NotImplementedError(prop)
yield Err(prop)
Loading

0 comments on commit 8f25263

Please sign in to comment.