Skip to content

Commit

Permalink
pattern match cond
Browse files Browse the repository at this point in the history
  • Loading branch information
arily committed Dec 5, 2023
1 parent db13a33 commit 41d89d1
Showing 1 changed file with 45 additions and 29 deletions.
74 changes: 45 additions & 29 deletions app/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
# print(params) # Output the parameters dictionary

# _example()

from typing import Tuple, Dict, Optional, Callable, Union

DatabaseAllowedNotNull = Union[str, int, bool, float]
Value = Union[DatabaseAllowedNotNull, None]
SQLBasic = Union[str, Tuple[Value, str]]
SQLPart = Union[SQLBasic, Tuple[SQLBasic, ...], Callable[[], SQLBasic]]
SQLValueWithTemplate = Tuple[Value, str]
SQLValueWithNested = Tuple[Value, "SQLPart"]
SQLType = Union[SQLValueWithTemplate, SQLValueWithNested]
SQLPart = Union[SQLType, Tuple[SQLType, ...], Callable[[], SQLType]]


class Nullable:
Expand All @@ -49,7 +50,7 @@ def nullable(value: DatabaseAllowedNotNull | None) -> Nullable:


def build(
*parts: SQLPart,
*parts: SQLPart | str,
) -> Tuple[str, Dict[str, Value]]:
parameters = {}
query_parts = (_process_query_part(p, parameters) for p in parts)
Expand All @@ -68,15 +69,20 @@ def _extract_value(value) -> Value:
return value


def _process_query_part(part: SQLPart, parameters: Dict[str, Value]) -> Optional[str]:
def _process_query_part(
part: SQLPart | str, parameters: Dict[str, Value]
) -> Optional[str]:
# late evaluation
if callable(part):
part = part()
if part is None:
return None
if isinstance(part, str):
return part
if isinstance(part, tuple):
return _process_tuple_part(part, parameters)

match part:
case None:
return None
case str(literal):
return literal
case tuple(parts):
return _process_tuple_part(parts, parameters)

raise TypeError(f"Unexpected type for query part: {type(part)}")

Expand All @@ -85,21 +91,31 @@ def _process_tuple_part(
part: SQLPart,
parameters: Dict[str, Value],
) -> Optional[str]:
value, query_part = part
extracted_value = _extract_value(value)

if isinstance(value, tuple):
parts: list[str] = []
for elem in part:
return_value = _process_query_part(elem, parameters)
if return_value is None:
return None
parts.append(return_value)
return " ".join(parts)

if extracted_value is not None or _is_nullable(value):
if ":" in query_part and extracted_value is not None:
parameter_name = query_part.split(":")[-1].strip().split(" ")[0].strip()
parameters[parameter_name] = extracted_value
return query_part
return None
match part:
case (
Nullable(value=cond) | bool(cond) | str(cond) | int(cond) | float(cond),
val,
):
evaluated = _process_query_part(val, parameters)

revealed_cond = _extract_value(cond)
print("revealed", revealed_cond, evaluated)
if revealed_cond is not None or _is_nullable(cond):
if ":" in evaluated and revealed_cond is not None:
parameter_name = (
evaluated.split(":")[-1].strip().split(" ")[0].strip()
)
parameters[parameter_name] = revealed_cond
return evaluated

case tuple(_), *_:
parts: list[str] = []
for elem in part:
return_value = _process_query_part(elem, parameters)
if return_value is None:
return None
parts.append(return_value)
return " ".join(parts)

case _:
raise TypeError(f"Unexpected type for query part: {type(part)}")

0 comments on commit 41d89d1

Please sign in to comment.