Skip to content

Commit

Permalink
query builder
Browse files Browse the repository at this point in the history
  • Loading branch information
arily committed Dec 5, 2023
1 parent 9c30c19 commit dae4420
Showing 1 changed file with 96 additions and 0 deletions.
96 changes: 96 additions & 0 deletions app/query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Example usage
# def _example():
# READ_PARAMS = "column1, column2"
# map_md5 = nullable("some_md5_value") # Use nullable(None) to include NULL
# user_id = 123

# query, params = build(
# f"SELECT {READ_PARAMS} FROM scores WHERE 1 = 1",
# (map_md5, "AND map_md5 = :map_md5"),
# (
# (user_id, "AND user_id = :user_id"),
# (nullable(None), "AND some_other_field IS NULL"),
# ),
# (None, "AND test_none = :test_none OR test_some = :test_none"),
# (
# (1, "AND test_nested = :test_nested"),
# (None, "AND nested_cond = :test_nested_none"),
# ),
# )

# print(query) # Output the constructed query
# print(params) # Output the parameters dictionary


from typing import Tuple, Dict, Union, Optional

DatabaseAllowedNotNull = Union[str, int, bool, float]
Nullable = Union[None, DatabaseAllowedNotNull]
SQLBasic = Union[Tuple[Nullable, str], str]
SQLPart = Union[SQLBasic, Tuple[Union[SQLBasic, "SQLPart"], ...]]


class Nullable:
def __init__(self, value: Optional[DatabaseAllowedNotNull]):
self.value = value


def nullable(value: Optional[DatabaseAllowedNotNull]) -> Nullable:
return Nullable(value)


def build(
*parts: SQLPart,
) -> Tuple[str, Dict[str, Union[DatabaseAllowedNotNull, None]]]:
parameters = {}
query_parts = (_process_query_part(p, parameters) for p in parts)

query = " ".join(q for q in query_parts if q is not None)
return query, parameters


def _is_nullable(value) -> bool:
return isinstance(value, Nullable)


def _extract_value(value) -> Union[DatabaseAllowedNotNull, None]:
if _is_nullable(value):
return value.value
return value


def _process_query_part(
part: SQLPart, parameters: Dict[str, Union[DatabaseAllowedNotNull, None]]
) -> Optional[str]:
if part is None:
return None
if isinstance(part, str):
return part
if isinstance(part, tuple):
return _process_tuple_part(part, parameters)

raise TypeError(f"Unexpected type for query part: {type(part)}")


def _process_tuple_part(
part: SQLPart,
parameters: Dict[str, Union[DatabaseAllowedNotNull, None]],
) -> Optional[str]:
value, query_part = part
extracted_value = _extract_value(value)

if isinstance(value, tuple):
parts: list[str] = []
for elem in part:
return_value = _process_tuple_part(elem, parameters)
if return_value is None:
return None
parts.append(return_value)
return " ".join(parts)

if extracted_value is not None or _is_nullable(value):
if ":" in query_part and extracted_value is not None:
parameter_name = query_part.split(":")[-1].strip().split(" ")[0].strip()
parameters[parameter_name] = extracted_value
return query_part
return None

0 comments on commit dae4420

Please sign in to comment.