From 41d89d13d176a7f9d062d1935be1e6731fd968e5 Mon Sep 17 00:00:00 2001 From: arily Date: Tue, 5 Dec 2023 16:41:36 +0900 Subject: [PATCH] pattern match cond --- app/query_builder.py | 74 +++++++++++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 29 deletions(-) diff --git a/app/query_builder.py b/app/query_builder.py index 52859bf46..c04a8e8f3 100644 --- a/app/query_builder.py +++ b/app/query_builder.py @@ -30,13 +30,14 @@ # print(params) # Output the parameters dictionary # _example() - from typing import Tuple, Dict, Optional, Callable, Union DatabaseAllowedNotNull = Union[str, int, bool, float] Value = Union[DatabaseAllowedNotNull, None] -SQLBasic = Union[str, Tuple[Value, str]] -SQLPart = Union[SQLBasic, Tuple[SQLBasic, ...], Callable[[], SQLBasic]] +SQLValueWithTemplate = Tuple[Value, str] +SQLValueWithNested = Tuple[Value, "SQLPart"] +SQLType = Union[SQLValueWithTemplate, SQLValueWithNested] +SQLPart = Union[SQLType, Tuple[SQLType, ...], Callable[[], SQLType]] class Nullable: @@ -49,7 +50,7 @@ def nullable(value: DatabaseAllowedNotNull | None) -> Nullable: def build( - *parts: SQLPart, + *parts: SQLPart | str, ) -> Tuple[str, Dict[str, Value]]: parameters = {} query_parts = (_process_query_part(p, parameters) for p in parts) @@ -68,15 +69,20 @@ def _extract_value(value) -> Value: return value -def _process_query_part(part: SQLPart, parameters: Dict[str, Value]) -> Optional[str]: +def _process_query_part( + part: SQLPart | str, parameters: Dict[str, Value] +) -> Optional[str]: + # late evaluation if callable(part): part = part() - if part is None: - return None - if isinstance(part, str): - return part - if isinstance(part, tuple): - return _process_tuple_part(part, parameters) + + match part: + case None: + return None + case str(literal): + return literal + case tuple(parts): + return _process_tuple_part(parts, parameters) raise TypeError(f"Unexpected type for query part: {type(part)}") @@ -85,21 +91,31 @@ def _process_tuple_part( part: SQLPart, parameters: Dict[str, Value], ) -> Optional[str]: - value, query_part = part - extracted_value = _extract_value(value) - - if isinstance(value, tuple): - parts: list[str] = [] - for elem in part: - return_value = _process_query_part(elem, parameters) - if return_value is None: - return None - parts.append(return_value) - return " ".join(parts) - - if extracted_value is not None or _is_nullable(value): - if ":" in query_part and extracted_value is not None: - parameter_name = query_part.split(":")[-1].strip().split(" ")[0].strip() - parameters[parameter_name] = extracted_value - return query_part - return None + match part: + case ( + Nullable(value=cond) | bool(cond) | str(cond) | int(cond) | float(cond), + val, + ): + evaluated = _process_query_part(val, parameters) + + revealed_cond = _extract_value(cond) + print("revealed", revealed_cond, evaluated) + if revealed_cond is not None or _is_nullable(cond): + if ":" in evaluated and revealed_cond is not None: + parameter_name = ( + evaluated.split(":")[-1].strip().split(" ")[0].strip() + ) + parameters[parameter_name] = revealed_cond + return evaluated + + case tuple(_), *_: + parts: list[str] = [] + for elem in part: + return_value = _process_query_part(elem, parameters) + if return_value is None: + return None + parts.append(return_value) + return " ".join(parts) + + case _: + raise TypeError(f"Unexpected type for query part: {type(part)}")