diff --git a/CHANGELOG.md b/CHANGELOG.md index aadf406..1e18c37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # ChangeLog +## 0.5 + +### 0.5.0 + +- Replace `get_sql` kwargs with `SqlContext` to improve performance + ## 0.4 ### 0.4.0 diff --git a/README.md b/README.md index b799724..ac6fc41 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,12 @@ The original repository includes many databases that Tortoise ORM doesn’t requ ## What changed? -Deleted unnecessary code that Tortoise ORM doesn’t require, and added features tailored specifically for Tortoise ORM. +Deleted unnecessary code that Tortoise ORM doesn’t require, added features tailored specifically for Tortoise ORM, +and modified to improve query generation performance. ## ThanksTo -- [pypika](https://github.com/kayak/pypika), a Python SQL query builder that exposes the full expressiveness of SQL, +- [pypika](https://github.com/kayak/pypika), a Python SQL query builder that exposes the full expressiveness of SQL, using a syntax that mirrors the resulting query structure. ## License diff --git a/pypika_tortoise/__init__.py b/pypika_tortoise/__init__.py index 2bebf65..0328fa7 100644 --- a/pypika_tortoise/__init__.py +++ b/pypika_tortoise/__init__.py @@ -1,3 +1,4 @@ +from .context import SqlContext from .dialects import MSSQLQuery, MySQLQuery, OracleQuery, PostgreSQLQuery, SQLLiteQuery from .enums import DatePart, Dialects, JoinType, Order from .exceptions import ( diff --git a/pypika_tortoise/context.py b/pypika_tortoise/context.py new file mode 100644 index 0000000..4030dd0 --- /dev/null +++ b/pypika_tortoise/context.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SqlContext: + """Represents the context for get_sql() methods to determine how to render SQL.""" + + quote_char: str + secondary_quote_char: str + alias_quote_char: str + dialect: "Dialects" + as_keyword: bool = False + subquery: bool = False + with_alias: bool = False + with_namespace: bool = False + subcriterion: bool = False + parameterizer: "Parameterizer" | None = None + groupby_alias: bool = True + orderby_alias: bool = True + + def copy(self, **kwargs) -> SqlContext: + return SqlContext( + quote_char=kwargs.get("quote_char", self.quote_char), + secondary_quote_char=kwargs.get("secondary_quote_char", self.secondary_quote_char), + alias_quote_char=kwargs.get("alias_quote_char", self.alias_quote_char), + dialect=kwargs.get("dialect", self.dialect), + as_keyword=kwargs.get("as_keyword", self.as_keyword), + subquery=kwargs.get("subquery", self.subquery), + with_alias=kwargs.get("with_alias", self.with_alias), + with_namespace=kwargs.get("with_namespace", self.with_namespace), + subcriterion=kwargs.get("subcriterion", self.subcriterion), + parameterizer=kwargs.get("parameterizer", self.parameterizer), + groupby_alias=kwargs.get("groupby_alias", self.groupby_alias), + orderby_alias=kwargs.get("orderby_alias", self.orderby_alias), + ) + + +from .enums import Dialects # noqa: E402 + +DEFAULT_SQL_CONTEXT = SqlContext( + quote_char='"', + secondary_quote_char="'", + alias_quote_char="", + as_keyword=False, + dialect=Dialects.SQLITE, +) + +from .terms import Parameterizer # noqa: E402 diff --git a/pypika_tortoise/dialects/mssql.py b/pypika_tortoise/dialects/mssql.py index 43c3897..5d24e5d 100644 --- a/pypika_tortoise/dialects/mssql.py +++ b/pypika_tortoise/dialects/mssql.py @@ -2,6 +2,7 @@ from typing import Any, cast +from ..context import DEFAULT_SQL_CONTEXT, SqlContext from ..enums import Dialects from ..exceptions import QueryException from ..queries import Query, QueryBuilder @@ -14,6 +15,8 @@ class MSSQLQuery(Query): Defines a query class for use with Microsoft SQL Server. """ + SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.MSSQL) + @classmethod def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder": return MSSQLQueryBuilder(**kwargs) @@ -23,7 +26,7 @@ class MSSQLQueryBuilder(QueryBuilder): QUERY_CLS = MSSQLQuery def __init__(self, **kwargs: Any) -> None: - super().__init__(dialect=Dialects.MSSQL, **kwargs) + super().__init__(**kwargs) self._top: int | None = None @builder @@ -45,44 +48,45 @@ def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return] # Overridden to provide a more domain-specific API for T-SQL users self._limit = cast(ValueWrapper, self.wrap_constant(limit)) - def _offset_sql(self, **kwargs) -> str: + def _offset_sql(self, ctx: SqlContext) -> str: order_by = "" if not self._orderbys: order_by = " ORDER BY (SELECT 0)" return order_by + " OFFSET {offset} ROWS".format( - offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0 + offset=self._offset.get_sql(ctx) if self._offset is not None else 0 ) - def _limit_sql(self, **kwargs) -> str: + def _limit_sql(self, ctx: SqlContext) -> str: if self._limit is None: return "" - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(ctx)) - def _apply_pagination(self, querystring: str, **kwargs) -> str: + def _apply_pagination(self, querystring: str, ctx: SqlContext) -> str: # Note: Overridden as MSSQL specifies offset before the fetch next limit if self._limit is not None or self._offset: # Offset has to be present if fetch next is specified in a MSSQL query - querystring += self._offset_sql(**kwargs) + querystring += self._offset_sql(ctx) if self._limit is not None: - querystring += self._limit_sql(**kwargs) + querystring += self._limit_sql(ctx) return querystring - def get_sql(self, *args: Any, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext | None = None) -> str: + if not ctx: + ctx = MSSQLQuery.SQL_CONTEXT # MSSQL does not support group by a field alias. # Note: set directly in kwargs as they are re-used down the tree in the case of subqueries! - kwargs["groupby_alias"] = False - return super().get_sql(*args, **kwargs) + ctx = ctx.copy(groupby_alias=False) + return super().get_sql(ctx) def _top_sql(self) -> str: return "TOP ({}) ".format(self._top) if self._top else "" - def _select_sql(self, **kwargs: Any) -> str: + def _select_sql(self, ctx: SqlContext) -> str: + ctx = ctx.copy(with_alias=True, subquery=True) return "SELECT {distinct}{top}{select}".format( top=self._top_sql(), distinct="DISTINCT " if self._distinct else "", - select=",".join( - term.get_sql(with_alias=True, subquery=True, **kwargs) for term in self._selects - ), + select=",".join(term.get_sql(ctx) for term in self._selects), ) diff --git a/pypika_tortoise/dialects/mysql.py b/pypika_tortoise/dialects/mysql.py index bb563f8..caf72fe 100644 --- a/pypika_tortoise/dialects/mysql.py +++ b/pypika_tortoise/dialects/mysql.py @@ -4,6 +4,7 @@ from datetime import time from typing import Any, cast +from ..context import DEFAULT_SQL_CONTEXT, SqlContext from ..enums import Dialects from ..queries import Query, QueryBuilder, Table from ..terms import ValueWrapper @@ -15,6 +16,8 @@ class MySQLQuery(Query): Defines a query class for use with MySQL. """ + SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.MYSQL, quote_char="`") + @classmethod def _builder(cls, **kwargs: Any) -> "MySQLQueryBuilder": return MySQLQueryBuilder(**kwargs) @@ -25,8 +28,8 @@ def load(cls, fp: str) -> "MySQLLoadQueryBuilder": class MySQLValueWrapper(ValueWrapper): - def get_value_sql(self, **kwargs: Any) -> str: - quote_char = kwargs.get("secondary_quote_char") or "" + def get_value_sql(self, ctx: SqlContext) -> str: + quote_char = ctx.secondary_quote_char or "" if isinstance(value := self.value, str): value = value.replace(quote_char, quote_char * 2) value = value.replace("\\", "\\\\") @@ -37,60 +40,54 @@ def get_value_sql(self, **kwargs: Any) -> str: elif isinstance(value, (dict, list)): value = format_quotes(json.dumps(value), quote_char) return value.replace("\\", "\\\\") - return super().get_value_sql(**kwargs) + return super().get_value_sql(ctx) class MySQLQueryBuilder(QueryBuilder): - QUOTE_CHAR = "`" QUERY_CLS = MySQLQuery def __init__(self, **kwargs: Any) -> None: super().__init__( - dialect=Dialects.MYSQL, wrapper_cls=MySQLValueWrapper, wrap_set_operation_queries=False, **kwargs, ) self._modifiers: list[str] = [] - def _on_conflict_sql(self, **kwargs: Any) -> str: - kwargs["alias_quote_char"] = ( - self.ALIAS_QUOTE_CHAR - if self.QUERY_ALIAS_QUOTE_CHAR is None - else self.QUERY_ALIAS_QUOTE_CHAR + def _on_conflict_sql(self, ctx: SqlContext) -> str: + ctx = ctx.copy( + as_keyword=True, ) - kwargs["as_keyword"] = True - querystring = format_alias_sql("", self.alias, **kwargs) - return querystring + return format_alias_sql("", self.alias, ctx) - def get_sql(self, **kwargs: Any) -> str: # type:ignore[override] - self._set_kwargs_defaults(kwargs) - querystring = super().get_sql(**kwargs) + def get_sql(self, ctx: SqlContext | None = None) -> str: + ctx = ctx or MySQLQuery.SQL_CONTEXT + querystring = super().get_sql(ctx) if querystring and self._update_table: if self._orderbys: - querystring += self._orderby_sql(**kwargs) + querystring += self._orderby_sql(ctx) if self._limit: - querystring += self._limit_sql() + querystring += self._limit_sql(ctx) return querystring - def _on_conflict_action_sql(self, **kwargs: Any) -> str: - kwargs.pop("with_namespace", None) + def _on_conflict_action_sql(self, ctx: SqlContext) -> str: + on_conflict_ctx = ctx.copy(with_namespace=False) if len(self._on_conflict_do_updates) > 0: updates = [] for field, value in self._on_conflict_do_updates: if value: updates.append( "{field}={value}".format( - field=field.get_sql(**kwargs), - value=value.get_sql(**kwargs), + field=field.get_sql(on_conflict_ctx), + value=value.get_sql(on_conflict_ctx), ) ) else: updates.append( "{field}={alias}.{value}".format( - field=field.get_sql(**kwargs), - alias=format_quotes(self.alias, self.QUOTE_CHAR), - value=field.get_sql(**kwargs), + field=field.get_sql(on_conflict_ctx), + alias=format_quotes(self.alias, ctx.quote_char), + value=field.get_sql(on_conflict_ctx), ) ) action_sql = " ON DUPLICATE KEY UPDATE {updates}".format(updates=",".join(updates)) @@ -107,23 +104,22 @@ def modifier(self, value: str) -> MySQLQueryBuilder: # type:ignore[return] """ self._modifiers.append(value) - def _select_sql(self, **kwargs: Any) -> str: + def _select_sql(self, ctx: SqlContext) -> str: """ Overridden function to generate the SELECT part of the SQL statement, with the addition of the a modifier if present. """ + ctx = ctx.copy(with_alias=True, subquery=True) return "SELECT {distinct}{modifier}{select}".format( distinct="DISTINCT " if self._distinct else "", modifier="{} ".format(" ".join(self._modifiers)) if self._modifiers else "", - select=",".join( - term.get_sql(with_alias=True, subquery=True, **kwargs) for term in self._selects - ), + select=",".join(term.get_sql(ctx) for term in self._selects), ) - def _insert_sql(self, **kwargs: Any) -> str: + def _insert_sql(self, ctx: SqlContext) -> str: insert_table = cast(Table, self._insert_table) return "INSERT {ignore}INTO {table}".format( - table=insert_table.get_sql(**kwargs), + table=insert_table.get_sql(ctx), ignore="IGNORE " if self._on_conflict_do_nothing else "", ) @@ -143,23 +139,26 @@ def load(self, fp: str) -> MySQLLoadQueryBuilder: # type:ignore[return] def into(self, table: str | Table) -> MySQLLoadQueryBuilder: # type:ignore[return] self._into_table = table if isinstance(table, Table) else Table(table) - def get_sql(self, *args: Any, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext | None = None) -> str: + if not ctx: + ctx = MySQLQuery.SQL_CONTEXT + querystring = "" if self._load_file and self._into_table: - querystring += self._load_file_sql(**kwargs) - querystring += self._into_table_sql(**kwargs) - querystring += self._options_sql(**kwargs) + querystring += self._load_file_sql(ctx) + querystring += self._into_table_sql(ctx) + querystring += self._options_sql(ctx) return querystring - def _load_file_sql(self, **kwargs: Any) -> str: + def _load_file_sql(self, ctx: SqlContext) -> str: return "LOAD DATA LOCAL INFILE '{}'".format(self._load_file) - def _into_table_sql(self, **kwargs: Any) -> str: + def _into_table_sql(self, ctx: SqlContext) -> str: table = cast(Table, self._into_table) - return " INTO TABLE `{}`".format(table.get_sql(**kwargs)) + return " INTO TABLE {}".format(table.get_sql(ctx)) - def _options_sql(self, **kwargs: Any) -> str: + def _options_sql(self, ctx: SqlContext) -> str: return " FIELDS TERMINATED BY ','" def __str__(self) -> str: diff --git a/pypika_tortoise/dialects/oracle.py b/pypika_tortoise/dialects/oracle.py index f4b064a..65dd22b 100644 --- a/pypika_tortoise/dialects/oracle.py +++ b/pypika_tortoise/dialects/oracle.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from typing import Any +from ..context import DEFAULT_SQL_CONTEXT, SqlContext from ..enums import Dialects from ..queries import Query, QueryBuilder @@ -9,31 +12,32 @@ class OracleQuery(Query): Defines a query class for use with Oracle. """ + SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.ORACLE, alias_quote_char='"') + @classmethod def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder": return OracleQueryBuilder(**kwargs) class OracleQueryBuilder(QueryBuilder): - QUOTE_CHAR = '"' QUERY_CLS = OracleQuery - ALIAS_QUOTE_CHAR = '"' def __init__(self, **kwargs: Any) -> None: - super().__init__(dialect=Dialects.ORACLE, **kwargs) + super().__init__(**kwargs) - def get_sql(self, *args: Any, **kwargs: Any) -> str: - # Oracle does not support group by a field alias - # Note: set directly in kwargs as they are re-used down the tree in the case of subqueries! - kwargs["groupby_alias"] = False - return super().get_sql(*args, **kwargs) + def get_sql(self, ctx: SqlContext | None = None) -> str: + if not ctx: + ctx = OracleQuery.SQL_CONTEXT + # Oracle does not support group by a field alias. + ctx = ctx.copy(groupby_alias=False) + return super().get_sql(ctx) - def _offset_sql(self, **kwargs) -> str: + def _offset_sql(self, ctx: SqlContext) -> str: if self._offset is None: return "" - return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs)) + return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(ctx)) - def _limit_sql(self, **kwargs) -> str: + def _limit_sql(self, ctx: SqlContext) -> str: if self._limit is None: return "" - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs)) + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(ctx)) diff --git a/pypika_tortoise/dialects/postgresql.py b/pypika_tortoise/dialects/postgresql.py index 36b10bd..2c94792 100644 --- a/pypika_tortoise/dialects/postgresql.py +++ b/pypika_tortoise/dialects/postgresql.py @@ -5,6 +5,7 @@ from copy import copy from typing import TYPE_CHECKING, Any +from ..context import DEFAULT_SQL_CONTEXT, SqlContext from ..enums import Dialects from ..exceptions import QueryException from ..queries import Query, QueryBuilder @@ -23,17 +24,18 @@ class PostgreSQLQuery(Query): Defines a query class for use with PostgreSQL. """ + SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.POSTGRESQL, alias_quote_char='"') + @classmethod def _builder(cls, **kwargs) -> "PostgreSQLQueryBuilder": return PostgreSQLQueryBuilder(**kwargs) class PostgreSQLQueryBuilder(QueryBuilder): - ALIAS_QUOTE_CHAR = '"' QUERY_CLS = PostgreSQLQuery def __init__(self, **kwargs: Any) -> None: - super().__init__(dialect=Dialects.POSTGRESQL, **kwargs) + super().__init__(**kwargs) self._returns: list[Term] = [] self._return_star = False @@ -53,14 +55,13 @@ def distinct_on(self, *fields: str | Term) -> "PostgreSQLQueryBuilder": # type: elif isinstance(field, Term): self._distinct_on.append(field) - def _distinct_sql(self, **kwargs: Any) -> str: + def _distinct_sql(self, ctx: SqlContext) -> str: + distinct_ctx = ctx.copy(with_alias=True) if self._distinct_on: return "DISTINCT ON({distinct_on}) ".format( - distinct_on=",".join( - term.get_sql(with_alias=True, **kwargs) for term in self._distinct_on - ) + distinct_on=",".join(term.get_sql(distinct_ctx) for term in self._distinct_on) ) - return super()._distinct_sql(**kwargs) + return super()._distinct_sql(distinct_ctx) @builder def returning(self, *terms: Any) -> "PostgreSQLQueryBuilder": # type:ignore[return] @@ -74,9 +75,7 @@ def returning(self, *terms: Any) -> "PostgreSQLQueryBuilder": # type:ignore[ret elif isinstance(term, Function): raise QueryException("Aggregate functions are not allowed in returning") else: - self._return_other( - self.wrap_constant(term, self._wrapper_cls) # type:ignore[arg-type] - ) + self._return_other(self.wrap_constant(term, self._wrapper_cls)) def _validate_returning_term(self, term: Term) -> None: for field in term.fields_(): @@ -135,13 +134,13 @@ def _return_other(self, function: Term) -> None: self._validate_returning_term(function) self._returns.append(function) - def _returning_sql(self, **kwargs: Any) -> str: + def _returning_sql(self, ctx: SqlContext) -> str: + returning_ctx = ctx.copy(with_alias=True) return " RETURNING {returning}".format( - returning=",".join(term.get_sql(with_alias=True, **kwargs) for term in self._returns), + returning=",".join(term.get_sql(returning_ctx) for term in self._returns), ) - def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: Any) -> str: - self._set_kwargs_defaults(kwargs) + def get_sql(self, ctx: SqlContext | None = None) -> str: if not (self._selects or self._insert_table or self._delete_from or self._update_table): return "" if self._insert_table and not (self._selects or self._values): @@ -155,43 +154,46 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An has_reference_to_foreign_table = self._foreign_table has_update_from = self._update_table and self._from - kwargs["with_namespace"] = any( - [ - has_joins, - has_multiple_from_clauses, - has_subquery_from_clause, - has_reference_to_foreign_table, - has_update_from, - ] + ctx = ctx or PostgreSQLQuery.SQL_CONTEXT + ctx = ctx.copy( + with_namespace=any( + [ + has_joins, + has_multiple_from_clauses, + has_subquery_from_clause, + has_reference_to_foreign_table, + has_update_from, + ] + ), ) if self._update_table: if self._with: - querystring = self._with_sql(**kwargs) + querystring = self._with_sql(ctx) else: querystring = "" - querystring += self._update_sql(**kwargs) + querystring += self._update_sql(ctx) - querystring += self._set_sql(**kwargs) + querystring += self._set_sql(ctx) if self._joins: self._from.append(self._update_table.as_(self._update_table.get_table_name() + "_")) if self._from: - querystring += self._from_sql(**kwargs) + querystring += self._from_sql(ctx) if self._joins: - querystring += " " + " ".join(join.get_sql(**kwargs) for join in self._joins) + querystring += " " + " ".join(join.get_sql(ctx) for join in self._joins) if self._wheres: - querystring += self._where_sql(**kwargs) + querystring += self._where_sql(ctx) if self._orderbys: - querystring += self._orderby_sql(**kwargs) + querystring += self._orderby_sql(ctx) if self._limit: - querystring += self._limit_sql() + querystring += self._limit_sql(ctx) else: - querystring = super().get_sql(with_alias, subquery, **kwargs) + querystring = super().get_sql(ctx) if self._returns: - kwargs["with_namespace"] = self._update_table and self.from_ - querystring += self._returning_sql(**kwargs) + returning_ctx = ctx.copy(with_namespace=self._update_table and self.from_) + querystring += self._returning_sql(returning_ctx) return querystring diff --git a/pypika_tortoise/dialects/sqlite.py b/pypika_tortoise/dialects/sqlite.py index f2249b8..38896f4 100644 --- a/pypika_tortoise/dialects/sqlite.py +++ b/pypika_tortoise/dialects/sqlite.py @@ -2,16 +2,17 @@ from typing import Any +from ..context import DEFAULT_SQL_CONTEXT, SqlContext from ..enums import Dialects from ..queries import Query, QueryBuilder from ..terms import ValueWrapper class SQLLiteValueWrapper(ValueWrapper): - def get_value_sql(self, **kwargs: Any) -> str: + def get_value_sql(self, ctx: SqlContext) -> str: if isinstance(self.value, bool): return "1" if self.value else "0" - return super().get_value_sql(**kwargs) + return super().get_value_sql(ctx) class SQLLiteQuery(Query): @@ -19,6 +20,8 @@ class SQLLiteQuery(Query): Defines a query class for use with Microsoft SQL Server. """ + SQL_CONTEXT = DEFAULT_SQL_CONTEXT.copy(dialect=Dialects.SQLITE) + @classmethod def _builder(cls, **kwargs: Any) -> "SQLLiteQueryBuilder": return SQLLiteQueryBuilder(**kwargs) @@ -28,10 +31,10 @@ class SQLLiteQueryBuilder(QueryBuilder): QUERY_CLS = SQLLiteQuery def __init__(self, **kwargs) -> None: - super().__init__(dialect=Dialects.SQLITE, wrapper_cls=SQLLiteValueWrapper, **kwargs) + super().__init__(wrapper_cls=SQLLiteValueWrapper, **kwargs) - def get_sql(self, **kwargs: Any) -> str: # type:ignore[override] - self._set_kwargs_defaults(kwargs) + def get_sql(self, ctx: SqlContext | None = None) -> str: + ctx = ctx or SQLLiteQuery.SQL_CONTEXT if not (self._selects or self._insert_table or self._delete_from or self._update_table): return "" if self._insert_table and not (self._selects or self._values): @@ -45,40 +48,42 @@ def get_sql(self, **kwargs: Any) -> str: # type:ignore[override] has_reference_to_foreign_table = self._foreign_table has_update_from = self._update_table and self._from - kwargs["with_namespace"] = any( - [ - has_joins, - has_multiple_from_clauses, - has_subquery_from_clause, - has_reference_to_foreign_table, - has_update_from, - ] + ctx = ctx.copy( + with_namespace=any( + [ + has_joins, + has_multiple_from_clauses, + has_subquery_from_clause, + has_reference_to_foreign_table, + has_update_from, + ] + ), ) if self._update_table: if self._with: - querystring = self._with_sql(**kwargs) + querystring = self._with_sql(ctx) else: querystring = "" - querystring += self._update_sql(**kwargs) + querystring += self._update_sql(ctx) - querystring += self._set_sql(**kwargs) + querystring += self._set_sql(ctx) if self._joins: self._from.append(self._update_table.as_(self._update_table.get_table_name() + "_")) if self._from: - querystring += self._from_sql(**kwargs) + querystring += self._from_sql(ctx) if self._joins: - querystring += " " + " ".join(join.get_sql(**kwargs) for join in self._joins) + querystring += " " + " ".join(join.get_sql(ctx) for join in self._joins) if self._wheres: - querystring += self._where_sql(**kwargs) + querystring += self._where_sql(ctx) if self._orderbys: - querystring += self._orderby_sql(**kwargs) + querystring += self._orderby_sql(ctx) if self._limit: - querystring += self._limit_sql() + querystring += self._limit_sql(ctx) else: - querystring = super().get_sql(**kwargs) + querystring = super().get_sql(ctx=ctx) return querystring diff --git a/pypika_tortoise/enums.py b/pypika_tortoise/enums.py index f524851..49e4c74 100644 --- a/pypika_tortoise/enums.py +++ b/pypika_tortoise/enums.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Any class Arithmetic(Enum): @@ -86,7 +85,7 @@ def __init__(self, name: str) -> None: def __call__(self, length: int) -> "SqlTypeLength": return SqlTypeLength(self.name, length) - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: "SqlContext") -> str: return "{name}".format(name=self.name) @@ -95,7 +94,7 @@ def __init__(self, name: str, length: int) -> None: self.name = name self.length = length - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: "SqlContext") -> str: return "{name}({length})".format(name=self.name, length=self.length) @@ -141,3 +140,6 @@ class JSONOperators(Enum): GET_TEXT_VALUE = "->>" GET_PATH_JSON_VALUE = "#>" GET_PATH_TEXT_VALUE = "#>>" + + +from .context import SqlContext # noqa: E402 diff --git a/pypika_tortoise/functions.py b/pypika_tortoise/functions.py index 186ecaa..ad5969e 100644 --- a/pypika_tortoise/functions.py +++ b/pypika_tortoise/functions.py @@ -8,6 +8,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any +from .context import SqlContext from .enums import SqlTypes from .terms import AggregateFunction, Function, Star, Term from .utils import builder @@ -25,8 +26,8 @@ def __init__(self, name: str, *args, **kwargs) -> None: super().__init__(name, *args, alias=alias) self._distinct = False - def get_function_sql(self, **kwargs) -> str: - s = super().get_function_sql(**kwargs) + def get_function_sql(self, ctx: SqlContext) -> str: + s = super().get_function_sql(ctx) n = len(self.name) + 1 if self._distinct: @@ -105,7 +106,7 @@ def __init__(self, term: Any, percentile: int | float | str, alias: str | None = super().__init__("APPROXIMATE_PERCENTILE", term, alias=alias) self.percentile = float(percentile) - def get_special_params_sql(self, **kwargs) -> str: + def get_special_params_sql(self, ctx: SqlContext) -> str: return f"USING PARAMETERS percentile={self.percentile}" @@ -115,9 +116,9 @@ def __init__(self, term: Any, as_type: Any, alias: str | None = None) -> None: super().__init__("CAST", term, alias=alias) self.as_type = as_type - def get_special_params_sql(self, **kwargs) -> str: + def get_special_params_sql(self, ctx: SqlContext) -> str: type_sql = ( - self.as_type.get_sql(**kwargs) + self.as_type.get_sql(ctx) if hasattr(self.as_type, "get_sql") else str(self.as_type).upper() ) @@ -130,7 +131,7 @@ def __init__(self, term: Any, encoding: Enum, alias: str | None = None) -> None: super().__init__("CONVERT", term, alias=alias) self.encoding = encoding - def get_special_params_sql(self, **kwargs) -> str: + def get_special_params_sql(self, ctx: SqlContext) -> str: return "USING {type}".format(type=self.encoding.value) @@ -275,7 +276,7 @@ class CurTimestamp(Function): def __init__(self, alias: str | None = None) -> None: super().__init__("CURRENT_TIMESTAMP", alias=alias) - def get_function_sql(self, **kwargs) -> str: + def get_function_sql(self, ctx: SqlContext) -> str: # CURRENT_TIMESTAMP takes no arguments, so the SQL to generate is quite # simple. Note that empty parentheses have been omitted intentionally. return "CURRENT_TIMESTAMP" @@ -296,8 +297,8 @@ def __init__(self, date_part: Any, field: Term, alias: str | None = None) -> Non super().__init__("EXTRACT", date_part, alias=alias) self.field = field - def get_special_params_sql(self, **kwargs) -> str: - return "FROM {field}".format(field=self.field.get_sql(**kwargs)) + def get_special_params_sql(self, ctx: SqlContext) -> str: + return "FROM {field}".format(field=self.field.get_sql(ctx)) # Null Functions diff --git a/pypika_tortoise/queries.py b/pypika_tortoise/queries.py index bd400be..429d62e 100644 --- a/pypika_tortoise/queries.py +++ b/pypika_tortoise/queries.py @@ -5,6 +5,7 @@ from functools import reduce from typing import TYPE_CHECKING, Any, Sequence, Type, cast, overload +from .context import DEFAULT_SQL_CONTEXT, SqlContext from .enums import Dialects, JoinType, SetOperation from .exceptions import JoinException, QueryException, RollupException, SetOperationException from .terms import ( @@ -59,7 +60,7 @@ def __getitem__(self, name: str) -> Field: def get_table_name(self) -> str: return self.alias - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: raise NotImplementedError() @@ -73,10 +74,10 @@ def __init__( self.name = name self.query = query - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: if self.query is None: return self.name - return self.query.get_sql(**kwargs) + return self.query.get_sql(ctx) def __eq__(self, other: Any) -> bool: return isinstance(other, AliasedQuery) and self.name == other.name @@ -111,13 +112,13 @@ def __ne__(self, other: Any) -> bool: def __getattr__(self, item: str) -> "Table": return Table(item, schema=self) - def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: # FIXME escape - schema_sql = format_quotes(self._name, quote_char) + schema_sql = format_quotes(self._name, ctx.quote_char) if self._parent is not None: return "{parent}.{schema}".format( - parent=self._parent.get_sql(quote_char=quote_char, **kwargs), + parent=self._parent.get_sql(ctx), schema=schema_sql, ) @@ -178,26 +179,23 @@ def __init__( def get_table_name(self) -> str: return self.alias or self._table_name - def get_sql(self, **kwargs: Any) -> str: - quote_char = kwargs.get("quote_char") + def get_sql(self, ctx: SqlContext) -> str: # FIXME escape - table_sql = format_quotes(self._table_name, quote_char) + table_sql = format_quotes(self._table_name, ctx.quote_char) if self._schema is not None: - table_sql = "{schema}.{table}".format( - schema=self._schema.get_sql(**kwargs), table=table_sql - ) + table_sql = "{schema}.{table}".format(schema=self._schema.get_sql(ctx), table=table_sql) if self._for: table_sql = "{table} FOR {criterion}".format( - table=table_sql, criterion=self._for.get_sql(**kwargs) + table=table_sql, criterion=self._for.get_sql(ctx) ) elif self._for_portion: table_sql = "{table} FOR PORTION OF {criterion}".format( - table=table_sql, criterion=self._for_portion.get_sql(**kwargs) + table=table_sql, criterion=self._for_portion.get_sql(ctx) ) - return format_alias_sql(table_sql, self.alias, **kwargs) + return format_alias_sql(table_sql, self.alias, ctx) @builder def for_(self, temporal_criterion: Criterion) -> "Self": # type:ignore[return] @@ -216,7 +214,7 @@ def for_portion(self, period_criterion: PeriodCriterion) -> "Self": # type:igno self._for_portion = period_criterion def __str__(self) -> str: - return self.get_sql(quote_char='"') + return self.get_sql(DEFAULT_SQL_CONTEXT) def __eq__(self, other: Any) -> bool: return ( @@ -310,33 +308,29 @@ def __init__( default if default is None or isinstance(default, Term) else ValueWrapper(default) ) - def get_name_sql(self, **kwargs: Any) -> str: - quote_char = kwargs.get("quote_char") - + def get_name_sql(self, ctx: SqlContext) -> str: column_sql = "{name}".format( - name=format_quotes(self.name, quote_char), + name=format_quotes(self.name, ctx.quote_char), ) return column_sql - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: column_sql = "{name}{type}{nullable}{default}".format( - name=self.get_name_sql(**kwargs), + name=self.get_name_sql(ctx), type=" {}".format(self.type) if self.type else "", nullable=( " {}".format("NULL" if self.nullable else "NOT NULL") if self.nullable is not None else "" ), - default=( - " {}".format("DEFAULT " + self.default.get_sql(**kwargs)) if self.default else "" - ), + default=(" {}".format("DEFAULT " + self.default.get_sql(ctx)) if self.default else ""), ) return column_sql def __str__(self) -> str: - return self.get_sql(quote_char='"') + return self.get_sql(DEFAULT_SQL_CONTEXT) def make_columns(*names: tuple[str, str] | str) -> list[Column]: @@ -369,13 +363,11 @@ def __init__( ) self.end_column = end_column if isinstance(end_column, Column) else Column(end_column) - def get_sql(self, **kwargs: Any) -> str: - quote_char = kwargs.get("quote_char") - + def get_sql(self, ctx: SqlContext) -> str: period_for_sql = "PERIOD FOR {name} ({start_column_name},{end_column_name})".format( - name=format_quotes(self.name, quote_char), - start_column_name=self.start_column.get_name_sql(**kwargs), - end_column_name=self.end_column.get_name_sql(**kwargs), + name=format_quotes(self.name, ctx.quote_char), + start_column_name=self.start_column.get_name_sql(ctx), + end_column_name=self.end_column.get_name_sql(ctx), ) return period_for_sql @@ -394,6 +386,8 @@ class Query: This class is immutable. """ + SQL_CONTEXT: SqlContext = DEFAULT_SQL_CONTEXT + @classmethod def _builder(cls, **kwargs: Any) -> "QueryBuilder": return QueryBuilder(**kwargs) @@ -597,25 +591,23 @@ def __sub__(self, other: "QueryBuilder") -> "Self": # type:ignore[override] return self.minus(other) def __str__(self) -> str: - return self.get_sql() + return self.get_sql(DEFAULT_SQL_CONTEXT) - def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: set_operation_template = " {type} {query_string}" - kwargs.setdefault("dialect", self.base_query.dialect) - # This initializes the quote char based on the base query, which could be a dialect specific query class - # This might be overridden if quote_char is set explicitly in kwargs - kwargs.setdefault("quote_char", self.base_query.QUOTE_CHAR) - - base_querystring = self.base_query.get_sql( - subquery=self.base_query.wrap_set_operation_queries, **kwargs + # Default to the base query's dialect and quote_char + ctx = ctx.copy( + dialect=self.base_query.dialect, + quote_char=self.base_query.QUERY_CLS.SQL_CONTEXT.quote_char, + parameterizer=ctx.parameterizer, ) + set_ctx = ctx.copy(subquery=self.base_query.wrap_set_operation_queries) + base_querystring = self.base_query.get_sql(set_ctx) querystring = base_querystring for set_operation, set_operation_query in self._set_operation: - set_operation_querystring = set_operation_query.get_sql( - subquery=self.base_query.wrap_set_operation_queries, **kwargs - ) + set_operation_querystring = set_operation_query.get_sql(set_ctx) if len(self.base_query._selects) != len(set_operation_query._selects): raise SetOperationException( @@ -630,24 +622,24 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An ) if self._orderbys: - querystring += self._orderby_sql(**kwargs) + querystring += self._orderby_sql(ctx) - querystring += self._limit_sql(**kwargs) - querystring += self._offset_sql(**kwargs) + querystring += self._limit_sql(ctx) + querystring += self._offset_sql(ctx) - if subquery: - querystring = "({query})".format(query=querystring, **kwargs) + if ctx.subquery: + querystring = "({query})".format(query=querystring) - if with_alias: + if ctx.with_alias: return format_alias_sql( querystring, self.alias or self._table_name, # type:ignore[arg-type] - **kwargs, + ctx, ) return querystring - def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: + def _orderby_sql(self, ctx: SqlContext) -> str: """ Produces the ORDER BY part of the query. This is a list of fields and possibly their directionality, ASC or DESC. The clauses are stored in the query under self._orderbys as a list of tuples containing the field and @@ -660,9 +652,9 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: selected_aliases = {s.alias for s in self.base_query._selects} for field, directionality in self._orderbys: term = ( - format_quotes(field.alias, quote_char) + format_quotes(field.alias, ctx.quote_char) if field.alias and field.alias in selected_aliases - else field.get_sql(quote_char=quote_char, **kwargs) + else field.get_sql(ctx) ) clauses.append( @@ -673,15 +665,15 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: return " ORDER BY {orderby}".format(orderby=",".join(clauses)) - def _offset_sql(self, **kwargs) -> str: + def _offset_sql(self, ctx: SqlContext) -> str: if self._offset is None: return "" - return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) + return " OFFSET {offset}".format(offset=self._offset.get_sql(ctx)) - def _limit_sql(self, **kwargs) -> str: + def _limit_sql(self, ctx: SqlContext) -> str: if self._limit is None: return "" - return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) + return " LIMIT {limit}".format(limit=self._limit.get_sql(ctx)) class QueryBuilder(Selectable, Term): # type:ignore[misc] @@ -690,19 +682,13 @@ class QueryBuilder(Selectable, Term): # type:ignore[misc] state to be branched immutably. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR: str | None = None - QUERY_ALIAS_QUOTE_CHAR: str | None = None QUERY_CLS = Query def __init__( self, - dialect: Dialects | None = None, wrap_set_operation_queries: bool = True, wrapper_cls: Type[ValueWrapper] = ValueWrapper, immutable: bool = True, - as_keyword: bool = False, ) -> None: super().__init__(None) # type:ignore[arg-type] @@ -747,8 +733,6 @@ def __init__( self._subquery_count = 0 self._foreign_table = False - self.dialect = dialect - self.as_keyword = as_keyword self.wrap_set_operation_queries = wrap_set_operation_queries self._wrapper_cls = wrapper_cls @@ -817,7 +801,7 @@ def _conflict_field_str(self, term: str) -> Field | None: return Field(term, table=self._insert_table) return None - def _on_conflict_sql(self, **kwargs: Any) -> str: + def _on_conflict_sql(self, ctx: SqlContext) -> str: if not self._on_conflict_do_nothing and len(self._on_conflict_do_updates) == 0: if not self._on_conflict_fields: return "" @@ -828,38 +812,40 @@ def _on_conflict_sql(self, **kwargs: Any) -> str: conflict_query = " ON CONFLICT" if self._on_conflict_fields: + on_conflict_ctx = ctx.copy(with_alias=True) fields = [ - f.get_sql(with_alias=True, **kwargs) # type:ignore[union-attr] + f.get_sql(on_conflict_ctx) # type:ignore[union-attr] for f in self._on_conflict_fields ] conflict_query += " (" + ", ".join(fields) + ")" if self._on_conflict_wheres: + where_ctx = ctx.copy(subquery=True) conflict_query += " WHERE {where}".format( - where=self._on_conflict_wheres.get_sql(subquery=True, **kwargs) + where=self._on_conflict_wheres.get_sql(where_ctx) ) return conflict_query - def _on_conflict_action_sql(self, **kwargs: Any) -> str: - kwargs.pop("with_namespace", None) + def _on_conflict_action_sql(self, ctx: SqlContext) -> str: + ctx = ctx.copy(with_namespace=False) if self._on_conflict_do_nothing: return " DO NOTHING" elif len(self._on_conflict_do_updates) > 0: updates = [] + value_ctx = ctx.copy(with_namespace=True) for field, value in self._on_conflict_do_updates: if value: updates.append( "{field}={value}".format( - field=field.get_sql(**kwargs), - value=value.get_sql(with_namespace=True, **kwargs), + field=field.get_sql(ctx), value=value.get_sql(value_ctx) ) ) else: updates.append( "{field}=EXCLUDED.{value}".format( - field=field.get_sql(**kwargs), - value=field.get_sql(**kwargs), + field=field.get_sql(ctx), + value=field.get_sql(ctx), ) ) action_sql = " DO UPDATE SET {updates}".format(updates=",".join(updates)) # nosec:B608 @@ -867,7 +853,7 @@ def _on_conflict_action_sql(self, **kwargs: Any) -> str: if self._on_conflict_do_update_wheres: action_sql += " WHERE {where}".format( where=self._on_conflict_do_update_wheres.get_sql( - subquery=True, with_namespace=True, **kwargs + ctx.copy(subquery=True, with_namespace=True) ) ) return action_sql @@ -944,9 +930,7 @@ def replace_table( # type:ignore[return] self._wheres.replace_table(current_table, new_table) if self._wheres else None ) self._prewheres = ( - self._prewheres.replace_table(current_table, new_table) # type:ignore[assignment] - if self._prewheres - else None + self._prewheres.replace_table(current_table, new_table) if self._prewheres else None ) self._groupbys = [ groupby.replace_table(current_table, new_table) for groupby in self._groupbys @@ -1090,7 +1074,7 @@ def prewhere(self, criterion: Criterion) -> "Self": # type:ignore[return] if self._prewheres: self._prewheres &= criterion else: - self._prewheres = criterion # type:ignore[assignment] + self._prewheres = criterion @builder def where(self, criterion: Term | EmptyCriterion) -> "Self": # type:ignore[return] @@ -1100,7 +1084,7 @@ def where(self, criterion: Term | EmptyCriterion) -> "Self": # type:ignore[retu if not self._validate_table(criterion): self._foreign_table = True if self._wheres: - self._wheres &= criterion # type:ignore[assignment,operator] + self._wheres &= criterion # type:ignore[operator] else: self._wheres = criterion else: @@ -1122,18 +1106,18 @@ def where(self, criterion: Term | EmptyCriterion) -> "Self": # type:ignore[retu @builder def having(self, criterion: Criterion) -> "Self": # type:ignore[return] if self._havings: - self._havings &= criterion # type:ignore[operator] + self._havings &= criterion else: - self._havings = criterion # type:ignore[assignment] + self._havings = criterion @builder def groupby(self, *terms: str | int | Term) -> "Self": # type:ignore[return] for term in terms: if isinstance(term, str): - term = Field(term, table=self._from[0]) # type:ignore[assignment] + term = Field(term, table=self._from[0]) elif isinstance(term, int): field = Field(str(term), table=self._from[0]) - term = field.wrap_constant(term) # type:ignore[assignment] + term = field.wrap_constant(term) self._groupbys.append(term) # type:ignore[arg-type] @@ -1167,7 +1151,7 @@ def rollup( # type:ignore[return] elif 0 < len(self._groupbys) and isinstance(self._groupbys[-1], Rollup): # If a rollup was added last, then append the new terms to the previous rollup - self._groupbys[-1].args += terms # type:ignore[arg-type] + self._groupbys[-1].args += terms else: self._groupbys.append(Rollup(*terms)) # type:ignore[arg-type] @@ -1288,8 +1272,8 @@ def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override] return self.slice(item) @staticmethod - def _list_aliases(field_set: Sequence[Field], quote_char: str | None = None) -> list[str]: - return [field.alias or field.get_sql(quote_char=quote_char) for field in field_set] + def _list_aliases(field_set: Sequence[Field], ctx: SqlContext) -> list[str]: + return [field.alias or field.get_sql(ctx) for field in field_set] def _select_field_str(self, term: str) -> None: if 0 == len(self._from): @@ -1387,7 +1371,7 @@ def _validate_terms_and_append(self, *terms: Any) -> None: ) def __str__(self) -> str: - return self.get_sql(dialect=self.dialect) + return self.get_sql(self.QUERY_CLS.SQL_CONTEXT) def __repr__(self) -> str: return self.__str__() @@ -1401,15 +1385,10 @@ def __ne__(self, other: Any) -> bool: # type:ignore[override] def __hash__(self) -> int: return hash(self.alias) + sum(hash(clause) for clause in self._from) - def _set_kwargs_defaults(self, kwargs: dict) -> None: - kwargs.setdefault("quote_char", self.QUOTE_CHAR) - kwargs.setdefault("secondary_quote_char", self.SECONDARY_QUOTE_CHAR) - kwargs.setdefault("alias_quote_char", self.ALIAS_QUOTE_CHAR) - kwargs.setdefault("as_keyword", self.as_keyword) - kwargs.setdefault("dialect", self.dialect) + def get_sql(self, ctx: SqlContext | None = None) -> str: + if not ctx: + ctx = self.QUERY_CLS.SQL_CONTEXT - def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: Any) -> str: - self._set_kwargs_defaults(kwargs) if not (self._selects or self._insert_table or self._delete_from or self._update_table): return "" if self._insert_table and not (self._selects or self._values): @@ -1423,166 +1402,172 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An has_reference_to_foreign_table = self._foreign_table has_update_from = self._update_table and self._from - kwargs["with_namespace"] = any( - [ - has_joins, - has_multiple_from_clauses, - has_subquery_from_clause, - has_reference_to_foreign_table, - has_update_from, - ] + ctx = ctx.copy( + with_namespace=any( + [ + has_joins, + has_multiple_from_clauses, + has_subquery_from_clause, + has_reference_to_foreign_table, + has_update_from, + ] + ) ) if self._update_table: if self._with: - querystring = self._with_sql(**kwargs) + querystring = self._with_sql(ctx) else: querystring = "" - querystring += self._update_sql(**kwargs) + querystring += self._update_sql(ctx) if self._joins: - querystring += " " + " ".join(join.get_sql(**kwargs) for join in self._joins) + querystring += " " + " ".join(join.get_sql(ctx) for join in self._joins) - querystring += self._set_sql(**kwargs) + querystring += self._set_sql(ctx) if self._from: - querystring += self._from_sql(**kwargs) + querystring += self._from_sql(ctx) if self._wheres: - querystring += self._where_sql(**kwargs) + querystring += self._where_sql(ctx) return querystring if self._delete_from: - querystring = self._delete_sql(**kwargs) + querystring = self._delete_sql(ctx) elif not self._select_into and self._insert_table: if self._with: - querystring = self._with_sql(**kwargs) + querystring = self._with_sql(ctx) else: querystring = "" if self._replace: - querystring += self._replace_sql(**kwargs) + querystring += self._replace_sql(ctx) else: - querystring += self._insert_sql(**kwargs) + querystring += self._insert_sql(ctx) if self._columns: - querystring += self._columns_sql(**kwargs) + querystring += self._columns_sql(ctx) if self._values: - querystring += self._values_sql(**kwargs) + querystring += self._values_sql(ctx) if self._on_conflict: - querystring += self._on_conflict_sql(**kwargs) - querystring += self._on_conflict_action_sql(**kwargs) + querystring += self._on_conflict_sql(ctx) + querystring += self._on_conflict_action_sql(ctx) return querystring else: - querystring += " " + self._select_sql(**kwargs) + querystring += " " + self._select_sql(ctx) else: if self._with: - querystring = self._with_sql(**kwargs) + querystring = self._with_sql(ctx) else: querystring = "" - querystring += self._select_sql(**kwargs) + querystring += self._select_sql(ctx) if self._insert_table: - querystring += self._into_sql(**kwargs) + querystring += self._into_sql(ctx) if self._from: - querystring += self._from_sql(**kwargs) + querystring += self._from_sql(ctx) if self._force_indexes: - querystring += self._force_index_sql(**kwargs) + querystring += self._force_index_sql(ctx) if self._use_indexes: - querystring += self._use_index_sql(**kwargs) + querystring += self._use_index_sql(ctx) if self._joins: - querystring += " " + " ".join(join.get_sql(**kwargs) for join in self._joins) + querystring += " " + " ".join(join.get_sql(ctx) for join in self._joins) if self._prewheres: - querystring += self._prewhere_sql(**kwargs) + querystring += self._prewhere_sql(ctx) if self._wheres: - querystring += self._where_sql(**kwargs) + querystring += self._where_sql(ctx) if self._groupbys: - querystring += self._group_sql(**kwargs) + querystring += self._group_sql(ctx) if self._mysql_rollup: querystring += self._rollup_sql() if self._havings: - querystring += self._having_sql(**kwargs) + querystring += self._having_sql(ctx) if self._orderbys: - querystring += self._orderby_sql(**kwargs) + querystring += self._orderby_sql(ctx) - querystring = self._apply_pagination(querystring, **kwargs) + querystring = self._apply_pagination(querystring, ctx) if self._for_update: - querystring += self._for_update_sql(**kwargs) + querystring += self._for_update_sql(ctx) - if subquery: + if ctx.subquery: querystring = "({query})".format(query=querystring) if self._on_conflict: - querystring += self._on_conflict_sql(**kwargs) - querystring += self._on_conflict_action_sql(**kwargs) - if with_alias: - kwargs["alias_quote_char"] = ( - self.ALIAS_QUOTE_CHAR - if self.QUERY_ALIAS_QUOTE_CHAR is None - else self.QUERY_ALIAS_QUOTE_CHAR - ) - return format_alias_sql(querystring, self.alias, **kwargs) + querystring += self._on_conflict_sql(ctx) + querystring += self._on_conflict_action_sql(ctx) + if ctx.with_alias: + return format_alias_sql(querystring, self.alias, ctx) return querystring - def _apply_pagination(self, querystring: str, **kwargs) -> str: - querystring += self._limit_sql(**kwargs) - querystring += self._offset_sql(**kwargs) + def _apply_pagination(self, querystring: str, ctx: SqlContext) -> str: + querystring += self._limit_sql(ctx) + querystring += self._offset_sql(ctx) return querystring - def _with_sql(self, **kwargs: Any) -> str: + def _with_sql(self, ctx: SqlContext) -> str: all_alias = [with_.alias for with_ in self._with] recursive = False for with_ in self._with: if with_.query.from_ in all_alias: # type:ignore[operator,union-attr] recursive = True break + + as_ctx = ctx.copy(subquery=False, with_alias=False) return f"WITH {'RECURSIVE ' if recursive else ''}" + ",".join( clause.alias + ( - "(" + ",".join([term.get_sql(**kwargs) for term in clause.terms]) + ")" + "(" + ",".join([term.get_sql(ctx) for term in clause.terms]) + ")" if clause.terms else "" ) + " AS (" - + clause.get_sql(subquery=False, with_alias=False, **kwargs) + + clause.get_sql(as_ctx) + ") " for clause in self._with ) - def get_parameterized_sql(self, **kwargs) -> tuple[str, list]: + def get_parameterized_sql(self, ctx: SqlContext | None = None) -> tuple[str, list]: """ Returns a tuple containing the query string and a list of parameters """ - parameterizer = kwargs.pop("parameterizer", Parameterizer()) + if not ctx: + ctx = self.QUERY_CLS.SQL_CONTEXT + + if not ctx.parameterizer: + ctx = ctx.copy(parameterizer=Parameterizer()) + return ( - self.get_sql(parameterizer=parameterizer, **kwargs), - parameterizer.values, + self.get_sql(ctx), + ctx.parameterizer.values, # type: ignore ) - def _distinct_sql(self, **kwargs: Any) -> str: + def _distinct_sql(self, ctx: SqlContext) -> str: return "DISTINCT " if self._distinct else "" - def _for_update_sql(self, **kwargs) -> str: + def _for_update_sql(self, ctx: SqlContext) -> str: if self._for_update: for_update = " FOR UPDATE" if self._for_update_of: - for_update += f' OF {", ".join([Table(item).get_sql(**kwargs) for item in self._for_update_of])}' + for_update += ( + f' OF {", ".join([Table(item).get_sql(ctx) for item in self._for_update_of])}' + ) if self._for_update_nowait: for_update += " NOWAIT" elif self._for_update_skip_locked: @@ -1592,88 +1577,80 @@ def _for_update_sql(self, **kwargs) -> str: return for_update - def _select_sql(self, **kwargs: Any) -> str: + def _select_sql(self, ctx: SqlContext) -> str: + select_ctx = ctx.copy(subquery=True, with_alias=True) return "SELECT {distinct}{select}".format( - distinct=self._distinct_sql(**kwargs), - select=",".join( - term.get_sql(with_alias=True, subquery=True, **kwargs) for term in self._selects - ), + distinct=self._distinct_sql(ctx), + select=",".join(term.get_sql(select_ctx) for term in self._selects), ) - def _insert_sql(self, **kwargs: Any) -> str: - table = self._insert_table.get_sql(**kwargs) # type:ignore[union-attr] + def _insert_sql(self, ctx: SqlContext) -> str: + table = self._insert_table.get_sql(ctx) # type:ignore[union-attr] return f"INSERT INTO {table}" - def _replace_sql(self, **kwargs: Any) -> str: - table = self._insert_table.get_sql(**kwargs) # type:ignore[union-attr] + def _replace_sql(self, ctx: SqlContext) -> str: + table = self._insert_table.get_sql(ctx) # type:ignore[union-attr] return f"REPLACE INTO {table}" @staticmethod - def _delete_sql(**kwargs: Any) -> str: + def _delete_sql(ctx: SqlContext) -> str: return "DELETE" - def _update_sql(self, **kwargs: Any) -> str: - table = self._update_table.get_sql(**kwargs) # type:ignore[union-attr] + def _update_sql(self, ctx: SqlContext) -> str: + table = self._update_table.get_sql(ctx) # type:ignore[union-attr] return f"UPDATE {table}" - def _columns_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: + def _columns_sql(self, ctx: SqlContext) -> str: """ SQL for Columns clause for INSERT queries - :param with_namespace: - Remove from kwargs, never format the column terms with namespaces since only one table can be inserted into """ - return " ({columns})".format( - columns=",".join(term.get_sql(with_namespace=False, **kwargs) for term in self._columns) - ) + # Remove from ctx, never format the column terms with namespaces since only one table can be inserted into + ctx = ctx.copy(with_namespace=False) + return " ({columns})".format(columns=",".join(term.get_sql(ctx) for term in self._columns)) - def _values_sql(self, **kwargs: Any) -> str: + def _values_sql(self, ctx: SqlContext) -> str: + values_ctx = ctx.copy(subquery=True, with_alias=True) return " VALUES ({values})".format( values="),(".join( - ",".join(term.get_sql(with_alias=True, subquery=True, **kwargs) for term in row) - for row in self._values + ",".join(term.get_sql(values_ctx) for term in row) for row in self._values ) ) - def _into_sql(self, **kwargs: Any) -> str: + def _into_sql(self, ctx: SqlContext) -> str: + into_ctx = ctx.copy(with_alias=False) return " INTO {table}".format( - table=self._insert_table.get_sql(with_alias=False, **kwargs), # type:ignore[union-attr] + table=self._insert_table.get_sql(into_ctx), # type:ignore[union-attr] ) - def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: + def _from_sql(self, ctx: SqlContext) -> str: + from_ctx = ctx.copy(subquery=True, with_alias=True) return " FROM {selectable}".format( - selectable=",".join( - clause.get_sql(subquery=True, with_alias=True, **kwargs) for clause in self._from - ) + selectable=",".join(clause.get_sql(from_ctx) for clause in self._from) ) - def _force_index_sql(self, **kwargs: Any) -> str: + def _force_index_sql(self, ctx: SqlContext) -> str: return " FORCE INDEX ({indexes})".format( - indexes=",".join(index.get_sql(**kwargs) for index in self._force_indexes), + indexes=",".join(index.get_sql(ctx) for index in self._force_indexes), ) - def _use_index_sql(self, **kwargs: Any) -> str: + def _use_index_sql(self, ctx: SqlContext) -> str: return " USE INDEX ({indexes})".format( - indexes=",".join(index.get_sql(**kwargs) for index in self._use_indexes), + indexes=",".join(index.get_sql(ctx) for index in self._use_indexes), ) - def _prewhere_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: + def _prewhere_sql(self, ctx: SqlContext) -> str: + prewhere_sql = ctx.copy(subquery=True) prewheres = cast(QueryBuilder, self._prewheres) - return " PREWHERE {prewhere}".format( - prewhere=prewheres.get_sql(quote_char=quote_char, subquery=True, **kwargs) - ) + return " PREWHERE {prewhere}".format(prewhere=prewheres.get_sql(prewhere_sql)) - def _where_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: + def _where_sql(self, ctx: SqlContext) -> str: + where_ctx = ctx.copy(subquery=True) wheres = cast(QueryBuilder, self._wheres) - return " WHERE {where}".format( - where=wheres.get_sql(quote_char=quote_char, subquery=True, **kwargs) - ) + return " WHERE {where}".format(where=wheres.get_sql(where_ctx)) def _group_sql( self, - quote_char: str | None = None, - alias_quote_char: str | None = None, - groupby_alias: bool = True, - **kwargs: Any, + ctx: SqlContext, ) -> str: """ Produces the GROUP BY part of the query. This is a list of fields. The clauses are stored in the query under @@ -1688,27 +1665,15 @@ def _group_sql( selected_aliases = {s.alias for s in self._selects} for field in self._groupbys: if (alias := field.alias) and alias in selected_aliases: - if groupby_alias: - clauses.append(format_quotes(alias, alias_quote_char or quote_char)) + if ctx.groupby_alias: + clauses.append(format_quotes(alias, ctx.alias_quote_char or ctx.quote_char)) else: for select in self._selects: if select.alias == alias: - clauses.append( - select.get_sql( - quote_char=quote_char, - alias_quote_char=alias_quote_char, - **kwargs, - ) - ) + clauses.append(select.get_sql(ctx)) break else: - clauses.append( - field.get_sql( - quote_char=quote_char, - alias_quote_char=alias_quote_char, - **kwargs, - ) - ) + clauses.append(field.get_sql(ctx)) sql = " GROUP BY {groupby}".format(groupby=",".join(clauses)) @@ -1719,10 +1684,7 @@ def _group_sql( def _orderby_sql( self, - quote_char: str | None = None, - alias_quote_char: str | None = None, - orderby_alias: bool = True, - **kwargs: Any, + ctx: SqlContext, ) -> str: """ Produces the ORDER BY part of the query. This is a list of fields and possibly their directionality, ASC or @@ -1738,11 +1700,9 @@ def _orderby_sql( selected_aliases = {s.alias for s in self._selects} for field, directionality in self._orderbys: term = ( - format_quotes(field.alias, alias_quote_char or quote_char) - if orderby_alias and field.alias and field.alias in selected_aliases - else field.get_sql( - quote_char=quote_char, alias_quote_char=alias_quote_char, **kwargs - ) + format_quotes(field.alias, ctx.alias_quote_char or ctx.quote_char) + if ctx.orderby_alias and field.alias and field.alias in selected_aliases + else field.get_sql(ctx) ) clauses.append( @@ -1756,26 +1716,27 @@ def _orderby_sql( def _rollup_sql(self) -> str: return " WITH ROLLUP" - def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: - having = self._havings.get_sql(quote_char=quote_char, **kwargs) # type:ignore[union-attr] + def _having_sql(self, ctx: SqlContext) -> str: + having = self._havings.get_sql(ctx) # type:ignore[union-attr] return f" HAVING {having}" - def _offset_sql(self, **kwargs) -> str: + def _offset_sql(self, ctx: SqlContext) -> str: if self._offset is None: return "" - return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs)) + return " OFFSET {offset}".format(offset=self._offset.get_sql(ctx)) - def _limit_sql(self, **kwargs) -> str: + def _limit_sql(self, ctx: SqlContext) -> str: if self._limit is None: return "" - return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs)) + return " LIMIT {limit}".format(limit=self._limit.get_sql(ctx)) - def _set_sql(self, **kwargs: Any) -> str: + def _set_sql(self, ctx: SqlContext) -> str: + field_ctx = ctx.copy(with_namespace=False) return " SET {set}".format( set=",".join( "{field}={value}".format( - field=field.get_sql(**dict(kwargs, with_namespace=False)), - value=value.get_sql(**kwargs), + field=field.get_sql(field_ctx), + value=value.get_sql(ctx), ) for field, value in self._updates ) @@ -1852,9 +1813,10 @@ def __init__(self, item: Term, how: JoinType) -> None: self.item = item self.how = how - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: + join_ctx = ctx.copy(subquery=True, with_alias=True) sql = "JOIN {table}".format( - table=self.item.get_sql(subquery=True, with_alias=True, **kwargs), + table=self.item.get_sql(join_ctx), ) if self.how.value: @@ -1894,11 +1856,12 @@ def __init__( self.criterion = criteria self.collate = collate - def get_sql(self, **kwargs: Any) -> str: - join_sql = super().get_sql(**kwargs) + def get_sql(self, ctx: SqlContext) -> str: + join_sql = super().get_sql(ctx) + criterion_ctx = ctx.copy(subquery=True) return "{join} ON {criterion}{collate}".format( join=join_sql, - criterion=self.criterion.get_sql(subquery=True, **kwargs), + criterion=self.criterion.get_sql(criterion_ctx), collate=" COLLATE {}".format(self.collate) if self.collate else "", ) @@ -1939,11 +1902,11 @@ def __init__(self, item: Term, how: JoinType, fields: Sequence[Field]) -> None: super().__init__(item, how) self.fields = fields - def get_sql(self, **kwargs: Any) -> str: - join_sql = super().get_sql(**kwargs) + def get_sql(self, ctx: SqlContext) -> str: + join_sql = super().get_sql(ctx) return "{join} USING ({fields})".format( join=join_sql, - fields=",".join(field.get_sql(**kwargs) for field in self.fields), + fields=",".join(field.get_sql(ctx) for field in self.fields), ) def validate(self, _from: Sequence[Table], _joins: Sequence[Table]) -> None: @@ -1974,9 +1937,6 @@ class CreateQueryBuilder: Query builder used to build CREATE queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR: str | None = None QUERY_CLS = Query def __init__(self, dialect: Dialects | None = None) -> None: @@ -1992,11 +1952,6 @@ def __init__(self, dialect: Dialects | None = None) -> None: self._if_not_exists = False self.dialect = dialect - def _set_kwargs_defaults(self, kwargs: dict) -> None: - kwargs.setdefault("quote_char", self.QUOTE_CHAR) - kwargs.setdefault("secondary_quote_char", self.SECONDARY_QUOTE_CHAR) - kwargs.setdefault("dialect", self.dialect) - @builder def create_table(self, table: Table | str) -> "Self": # type:ignore[return] """ @@ -2160,14 +2115,14 @@ def as_select(self, query_builder: QueryBuilder) -> "Self": # type:ignore[retur def if_not_exists(self) -> "Self": # type:ignore[return] self._if_not_exists = True - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext | None) -> str: """ Gets the sql statement string. :return: The create table statement. :rtype: str """ - self._set_kwargs_defaults(kwargs) + ctx = ctx or self.QUERY_CLS.SQL_CONTEXT if not self._create_table: return "" @@ -2175,19 +2130,19 @@ def get_sql(self, **kwargs: Any) -> str: if not self._columns and not self._as_select: return "" - create_table = self._create_table_sql(**kwargs) + create_table = self._create_table_sql(ctx) if self._as_select: - return create_table + self._as_select_sql(**kwargs) + return create_table + self._as_select_sql(ctx) - body = self._body_sql(**kwargs) - table_options = self._table_options_sql(**kwargs) + body = self._body_sql(ctx) + table_options = self._table_options_sql(ctx) return "{create_table} ({body}){table_options}".format( create_table=create_table, body=body, table_options=table_options ) - def _create_table_sql(self, **kwargs: Any) -> str: + def _create_table_sql(self, ctx: SqlContext) -> str: table_type = "" if self._temporary: table_type = "TEMPORARY " @@ -2201,10 +2156,10 @@ def _create_table_sql(self, **kwargs: Any) -> str: return "CREATE {table_type}TABLE {if_not_exists}{table}".format( table_type=table_type, if_not_exists=if_not_exists, - table=self._create_table.get_sql(**kwargs), # type:ignore[attr-defined,union-attr] + table=self._create_table.get_sql(ctx), # type: ignore ) - def _table_options_sql(self, **kwargs) -> str: + def _table_options_sql(self, ctx: SqlContext) -> str: table_options = "" if self._with_system_versioning: @@ -2212,44 +2167,44 @@ def _table_options_sql(self, **kwargs) -> str: return table_options - def _column_clauses(self, **kwargs) -> list[str]: - return [column.get_sql(**kwargs) for column in self._columns] + def _column_clauses(self, ctx: SqlContext) -> list[str]: + return [column.get_sql(ctx) for column in self._columns] - def _period_for_clauses(self, **kwargs) -> list[str]: - return [period_for.get_sql(**kwargs) for period_for in self._period_fors] + def _period_for_clauses(self, ctx: SqlContext) -> list[str]: + return [period_for.get_sql(ctx) for period_for in self._period_fors] - def _unique_key_clauses(self, **kwargs) -> list[str]: + def _unique_key_clauses(self, ctx: SqlContext) -> list[str]: return [ "UNIQUE ({unique})".format( - unique=",".join(column.get_name_sql(**kwargs) for column in unique) + unique=",".join(column.get_name_sql(ctx) for column in unique) ) for unique in self._uniques ] - def _primary_key_clause(self, **kwargs) -> str: + def _primary_key_clause(self, ctx: SqlContext) -> str: columns = ",".join( - column.get_name_sql(**kwargs) for column in self._primary_key # type:ignore[union-attr] + column.get_name_sql(ctx) for column in self._primary_key # type:ignore[union-attr] ) return f"PRIMARY KEY ({columns})" - def _body_sql(self, **kwargs) -> str: - clauses = self._column_clauses(**kwargs) - clauses += self._period_for_clauses(**kwargs) - clauses += self._unique_key_clauses(**kwargs) + def _body_sql(self, ctx: SqlContext) -> str: + clauses = self._column_clauses(ctx) + clauses += self._period_for_clauses(ctx) + clauses += self._unique_key_clauses(ctx) # Primary keys if self._primary_key: - clauses.append(self._primary_key_clause(**kwargs)) + clauses.append(self._primary_key_clause(ctx)) return ",".join(clauses) - def _as_select_sql(self, **kwargs: Any) -> str: + def _as_select_sql(self, ctx: SqlContext) -> str: return " AS ({query})".format( - query=self._as_select.get_sql(**kwargs), # type:ignore[union-attr] + query=self._as_select.get_sql(ctx), # type:ignore[union-attr] ) def __str__(self) -> str: - return self.get_sql() + return self.get_sql(self.QUERY_CLS.SQL_CONTEXT) def __repr__(self) -> str: return self.__str__() @@ -2260,28 +2215,20 @@ class DropQueryBuilder: Query builder used to build DROP queries. """ - QUOTE_CHAR = '"' - SECONDARY_QUOTE_CHAR = "'" - ALIAS_QUOTE_CHAR: str | None = None + SQL_CONTEXT = DEFAULT_SQL_CONTEXT QUERY_CLS = Query - def __init__(self, dialect: Dialects | None = None) -> None: + def __init__(self) -> None: self._drop_table: Table | None = None self._if_exists: bool | None = None - self.dialect = dialect - - def _set_kwargs_defaults(self, kwargs: dict) -> None: - kwargs.setdefault("quote_char", self.QUOTE_CHAR) - kwargs.setdefault("secondary_quote_char", self.SECONDARY_QUOTE_CHAR) - kwargs.setdefault("dialect", self.dialect) - def get_sql(self, **kwargs: Any) -> str: - self._set_kwargs_defaults(kwargs) + def get_sql(self, ctx: SqlContext | None = None) -> str: + ctx = ctx or self.SQL_CONTEXT if not self._drop_table: return "" - querystring = self._drop_table_sql(**kwargs) + querystring = self._drop_table_sql(ctx) return querystring @@ -2296,16 +2243,16 @@ def drop_table(self, table: Table | str) -> "Self": # type:ignore[return] def if_exists(self) -> "Self": # type:ignore[return] self._if_exists = True - def _drop_table_sql(self, **kwargs: Any) -> str: + def _drop_table_sql(self, ctx: SqlContext) -> str: if_exists = "IF EXISTS " if self._if_exists else "" drop_table = cast(Table, self._drop_table) return "DROP TABLE {if_exists}{table}".format( if_exists=if_exists, - table=drop_table.get_sql(**kwargs), + table=drop_table.get_sql(ctx), ) def __str__(self) -> str: - return self.get_sql() + return self.get_sql(self.QUERY_CLS.SQL_CONTEXT) def __repr__(self) -> str: return self.__str__() diff --git a/pypika_tortoise/terms.py b/pypika_tortoise/terms.py index 1c8ce20..3271df7 100644 --- a/pypika_tortoise/terms.py +++ b/pypika_tortoise/terms.py @@ -9,6 +9,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, Type, TypeVar, cast +from .context import DEFAULT_SQL_CONTEXT, SqlContext from .enums import ( Arithmetic, Boolean, @@ -48,7 +49,7 @@ def find_(self, type: Type[NodeT]) -> list[NodeT]: class Term(Node): - is_aggregate: bool | None = False # type:ignore[assignment] + is_aggregate: bool | None = False def __init__(self, alias: str | None = None) -> None: self.alias = alias @@ -138,7 +139,7 @@ def notnull(self) -> "Not": return self.isnull().negate() def bitwiseand(self, value: int) -> "BitwiseAndCriterion": - return BitwiseAndCriterion(self, self.wrap_constant(value)) # type:ignore[arg-type] + return BitwiseAndCriterion(self, self.wrap_constant(value)) def gt(self, other: Any) -> "BasicCriterion": return self > other @@ -156,54 +157,34 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.glob, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.like, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.not_like, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.ilike, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.not_ilike, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.rlike, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion( - Matching.regex, self, self.wrap_constant(pattern) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": - return BetweenCriterion( - self, self.wrap_constant(lower), self.wrap_constant(upper) # type:ignore[arg-type] - ) + return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) def from_to(self, start: Any, end: Any) -> "PeriodCriterion": - return PeriodCriterion( - self, self.wrap_constant(start), self.wrap_constant(end) # type:ignore[arg-type] - ) + return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion( - Matching.as_of, self, self.wrap_constant(expr) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) def all_(self) -> "All": return All(self) @@ -217,9 +198,7 @@ def notin(self, arg: list | tuple | set | "Term") -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion( - Matching.bin_regex, self, self.wrap_constant(pattern) # type:ignore[arg-type] - ) + return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) def negate(self) -> "Not": return Not(self) @@ -234,24 +213,16 @@ def __neg__(self) -> "Negative": return Negative(self) def __add__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.add, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.add, self, self.wrap_constant(other)) def __sub__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.sub, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.sub, self, self.wrap_constant(other)) def __mul__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.mul, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.mul, self, self.wrap_constant(other)) def __truediv__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.div, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.div, self, self.wrap_constant(other)) def __pow__(self, other: Any) -> "Pow": return Pow(self, other) @@ -260,46 +231,34 @@ def __mod__(self, other: Any) -> "Mod": return Mod(self, other) def __radd__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.add, self.wrap_constant(other), self # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.add, self.wrap_constant(other), self) def __rsub__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.sub, self.wrap_constant(other), self # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.sub, self.wrap_constant(other), self) def __rmul__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.mul, self.wrap_constant(other), self # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.mul, self.wrap_constant(other), self) def __rtruediv__(self, other: Any) -> "ArithmeticExpression": - return ArithmeticExpression( - Arithmetic.div, self.wrap_constant(other), self # type:ignore[arg-type] - ) + return ArithmeticExpression(Arithmetic.div, self.wrap_constant(other), self) def __eq__(self, other: Any) -> "BasicCriterion": # type:ignore[override] - return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) # type:ignore[arg-type] + return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) def __ne__(self, other: Any) -> "BasicCriterion": # type:ignore[override] - return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) # type:ignore[arg-type] + return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) # type:ignore[arg-type] + return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion( - Equality.gte, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) # type:ignore[arg-type] + return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion( - Equality.lte, self, self.wrap_constant(other) # type:ignore[arg-type] - ) + return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -307,12 +266,13 @@ def __getitem__(self, item: slice) -> "BetweenCriterion": return self.between(item.start, item.stop) def __str__(self) -> str: - return self.get_sql(quote_char='"', secondary_quote_char="'") + return self.get_sql(DEFAULT_SQL_CONTEXT) def __hash__(self) -> int: - return hash(self.get_sql(with_alias=True)) + ctx = DEFAULT_SQL_CONTEXT.copy(with_alias=True) + return hash(self.get_sql(ctx)) - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: raise NotImplementedError() @@ -345,12 +305,11 @@ def __init__(self, placeholder: str | None = None, idx: int | None = None) -> No self._placeholder = placeholder self._idx = idx - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: if self._placeholder: return self._placeholder - dialect = kwargs.get("dialect", None) - return self.IDX_PLACEHOLDERS.get(dialect, lambda _: self.DEFAULT_PLACEHOLDER)(self._idx) + return self.IDX_PLACEHOLDERS.get(ctx.dialect, lambda _: self.DEFAULT_PLACEHOLDER)(self._idx) class Parameterizer: @@ -366,7 +325,7 @@ class Parameterizer: >>> sql, parameterizer.values ('SELECT "id" FROM "customers" WHERE "lname"=?', ['Mustermann']) - Parameterizer remembers the values it has seen and replaces them with parameters. The values can + Parameterizer remembers the values it has seen and replaces them with parameters. The values can be accessed via the `values` attribute. """ @@ -397,10 +356,10 @@ def __init__(self, term: Term) -> None: @property def is_aggregate(self) -> bool | None: # type:ignore[override] - return self.term.is_aggregate # type:ignore[has-type] + return self.term.is_aggregate - def get_sql(self, **kwargs: Any) -> str: - return "-{term}".format(term=self.term.get_sql(**kwargs)) + def get_sql(self, ctx: SqlContext) -> str: + return "-{term}".format(term=self.term.get_sql(ctx)) class ValueWrapper(Term): @@ -424,29 +383,29 @@ def __init__( self.value = value self.allow_parametrize = allow_parametrize - def get_value_sql(self, **kwargs: Any) -> str: - return self.get_formatted_value(self.value, **kwargs) + def get_value_sql(self, ctx: SqlContext) -> str: + return self.get_formatted_value(self.value, ctx) @classmethod - def get_formatted_value(cls, value: Any, **kwargs) -> str: - quote_char = kwargs.get("secondary_quote_char") or "" + def get_formatted_value(cls, value: Any, ctx: SqlContext) -> str: + quote_char = ctx.secondary_quote_char or "" # FIXME escape values if isinstance(value, Term): - return value.get_sql(**kwargs) + return value.get_sql(ctx) if isinstance(value, Enum): if isinstance(value, DatePart): return value.value - return cls.get_formatted_value(value.value, **kwargs) + return cls.get_formatted_value(value.value, ctx) if isinstance(value, (date, time)): - return cls.get_formatted_value(value.isoformat(), **kwargs) + return cls.get_formatted_value(value.isoformat(), ctx) if isinstance(value, str): value = value.replace(quote_char, quote_char * 2) return format_quotes(value, quote_char) if isinstance(value, bool): return str(value).lower() if isinstance(value, uuid.UUID): - return cls.get_formatted_value(str(value), **kwargs) + return cls.get_formatted_value(str(value), ctx) if isinstance(value, (dict, list)): return format_quotes(json.dumps(value), quote_char) if value is None: @@ -455,25 +414,18 @@ def get_formatted_value(cls, value: Any, **kwargs) -> str: def get_sql( self, - quote_char: str | None = None, - secondary_quote_char: str = "'", - parameterizer: Parameterizer | None = None, - **kwargs: Any, + ctx: SqlContext, ) -> str: if ( - parameterizer is None - or not parameterizer.should_parameterize(self.value) + ctx.parameterizer is None + or not ctx.parameterizer.should_parameterize(self.value) or not self.allow_parametrize ): - sql = self.get_value_sql( - quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs - ) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + sql = self.get_value_sql(ctx) + return format_alias_sql(sql, self.alias, ctx) - param = parameterizer.create_param(self.value) - return format_alias_sql( - param.get_sql(**kwargs), self.alias, quote_char=quote_char, **kwargs - ) + param = ctx.parameterizer.create_param(self.value) + return format_alias_sql(param.get_sql(ctx), self.alias, ctx) class JSON(Term): @@ -510,22 +462,22 @@ def _get_list_sql(self, value: list, **kwargs: Any) -> str: def _get_str_sql(value: str, quote_char: str = '"', **kwargs: Any) -> str: return format_quotes(value, quote_char) - def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: - sql = format_quotes(self._recursive_get_sql(self.value), secondary_quote_char) - return format_alias_sql(sql, self.alias, **kwargs) + def get_sql(self, ctx: SqlContext) -> str: + sql = format_quotes(self._recursive_get_sql(self.value), ctx.secondary_quote_char) + return format_alias_sql(sql, self.alias, ctx) def get_json_value(self, key_or_index: str | int) -> "BasicCriterion": return BasicCriterion( JSONOperators.GET_JSON_VALUE, self, - self.wrap_constant(key_or_index), # type:ignore[arg-type] + self.wrap_constant(key_or_index), ) def get_text_value(self, key_or_index: str | int) -> "BasicCriterion": return BasicCriterion( JSONOperators.GET_TEXT_VALUE, self, - self.wrap_constant(key_or_index), # type:ignore[arg-type] + self.wrap_constant(key_or_index), ) def get_path_json_value(self, path_json: str) -> "BasicCriterion": @@ -575,8 +527,8 @@ def __init__(self, field: str | "Field") -> None: super().__init__(None) self.field = Field(field) if not isinstance(field, Field) else field - def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: - return "VALUES({value})".format(value=self.field.get_sql(quote_char=quote_char, **kwargs)) + def get_sql(self, ctx: SqlContext) -> str: + return "VALUES({value})".format(value=self.field.get_sql(ctx)) class LiteralValue(Term): @@ -584,8 +536,8 @@ def __init__(self, value, alias: str | None = None) -> None: super().__init__(alias) self._value = value - def get_sql(self, **kwargs: Any) -> str: - return format_alias_sql(self._value, self.alias, **kwargs) + def get_sql(self, ctx: SqlContext) -> str: + return format_alias_sql(self._value, self.alias, ctx) class NullValue(LiteralValue): @@ -626,7 +578,7 @@ def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": return crit - def get_sql(self, **kwargs) -> str: # type:ignore[override] + def get_sql(self, ctx: SqlContext) -> str: raise NotImplementedError() @@ -678,26 +630,22 @@ def replace_table( # type:ignore[return] A copy of the field with the tables replaced. """ if self.table == current_table: - self.table = new_table # type:ignore[assignment] - - def get_sql(self, **kwargs: Any) -> str: # type:ignore[override] - with_alias = kwargs.pop("with_alias", False) - with_namespace = kwargs.pop("with_namespace", False) - quote_char = kwargs.pop("quote_char", None) + self.table = new_table - field_sql = format_quotes(self.name, quote_char) + def get_sql(self, ctx: SqlContext) -> str: + field_sql = format_quotes(self.name, ctx.quote_char) # Need to add namespace if the table has an alias - if self.table and (with_namespace or self.table.alias): + if self.table and (ctx.with_namespace or self.table.alias): table_name = self.table.get_table_name() field_sql = "{namespace}.{name}".format( - namespace=format_quotes(table_name, quote_char), + namespace=format_quotes(table_name, ctx.quote_char), name=field_sql, ) field_alias = getattr(self, "alias", None) - if with_alias: - return format_alias_sql(field_sql, field_alias, quote_char=quote_char, **kwargs) + if ctx.with_alias: + return format_alias_sql(field_sql, field_alias, ctx) return field_sql @@ -706,8 +654,8 @@ def __init__(self, name: str, alias: str | None = None) -> None: super().__init__(alias) self.name = name - def get_sql(self, quote_char: str | None = None, **kwargs: Any) -> str: - return format_quotes(self.name, quote_char) + def get_sql(self, ctx: SqlContext) -> str: + return format_quotes(self.name, ctx.quote_char) class Star(Field): @@ -719,16 +667,10 @@ def nodes_(self) -> Iterator[NodeT]: if self.table is not None: yield from self.table.nodes_() - def get_sql( # type:ignore[override] - self, - with_alias: bool = False, - with_namespace: bool = False, - quote_char: str | None = None, - **kwargs: Any, - ) -> str: - if self.table and (with_namespace or self.table.alias): + def get_sql(self, ctx: SqlContext) -> str: + if self.table and (ctx.with_namespace or self.table.alias): namespace = self.table.alias or getattr(self.table, "_table_name") - return "{}.*".format(format_quotes(namespace, quote_char)) + return "{}.*".format(format_quotes(namespace, ctx.quote_char)) return "*" @@ -741,19 +683,15 @@ def __init__(self, *values: Any) -> None: def nodes_(self) -> Iterator[NodeT]: yield self # type:ignore[misc] for value in self.values: - yield from value.nodes_() # type:ignore[union-attr] + yield from value.nodes_() - def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format( - ",".join(term.get_sql(**kwargs) for term in self.values) # type:ignore[union-attr] - ) - return format_alias_sql(sql, self.alias, **kwargs) + def get_sql(self, ctx: SqlContext) -> str: + sql = "({})".format(",".join(term.get_sql(ctx) for term in self.values)) + return format_alias_sql(sql, self.alias, ctx) @property def is_aggregate(self) -> bool | None: # type:ignore[override] - return resolve_is_aggregate( - [val.is_aggregate for val in self.values] # type:ignore[has-type,union-attr] - ) + return resolve_is_aggregate([val.is_aggregate for val in self.values]) @builder def replace_table( # type:ignore[return] @@ -769,10 +707,7 @@ def replace_table( # type:ignore[return] :return: A copy of the field with the tables replaced. """ - self.values = [ - value.replace_table(current_table, new_table) # type:ignore[misc,union-attr] - for value in self.values - ] + self.values = [value.replace_table(current_table, new_table) for value in self.values] class Array(Tuple): @@ -780,21 +715,20 @@ def __init__(self, *values: Any) -> None: super().__init__(*values) self.original_value = list(values) - def get_sql(self, parameterizer: Parameterizer | None = None, **kwargs: Any) -> str: - if parameterizer is None or not parameterizer.should_parameterize(self.original_value): - dialect = kwargs.get("dialect", None) - values = ",".join( - term.get_sql(**kwargs) for term in self.values - ) # type:ignore[union-attr] + def get_sql(self, ctx: SqlContext) -> str: + if ctx.parameterizer is None or not ctx.parameterizer.should_parameterize( + self.original_value + ): + values = ",".join(term.get_sql(ctx) for term in self.values) sql = "[{}]".format(values) - if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): + if ctx.dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): sql = "ARRAY[{}]".format(values) if len(values) > 0 else "'{}'" - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) - param = parameterizer.create_param(self.original_value) - return param.get_sql(**kwargs) + param = ctx.parameterizer.create_param(self.original_value) + return param.get_sql(ctx) class Bracket(Tuple): @@ -849,17 +783,17 @@ def replace_table( # type:ignore[return] self.right = self.right.replace_table(current_table, new_table) self.nested = self.right.replace_table(current_table, new_table) - def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "{left}{comparator}{right}{nested_comparator}{nested}".format( - left=self.left.get_sql(**kwargs), + left=self.left.get_sql(ctx), comparator=self.comparator.value, - right=self.right.get_sql(**kwargs), + right=self.right.get_sql(ctx), nested_comparator=self.nested_comparator.value, # type:ignore[attr-defined] - nested=self.nested.get_sql(**kwargs), + nested=self.nested.get_sql(ctx), ) - if with_alias: - return format_alias_sql(sql=sql, alias=self.alias, **kwargs) + if ctx.with_alias: + return format_alias_sql(sql=sql, alias=self.alias, ctx=ctx) return sql @@ -897,7 +831,7 @@ def nodes_(self) -> Iterator[NodeT]: @property def is_aggregate(self) -> bool | None: # type:ignore[override] - aggrs = [term.is_aggregate for term in (self.left, self.right)] # type:ignore[has-type] + aggrs = [term.is_aggregate for term in (self.left, self.right)] return resolve_is_aggregate(aggrs) @builder @@ -917,14 +851,14 @@ def replace_table( # type:ignore[return] self.left = self.left.replace_table(current_table, new_table) self.right = self.right.replace_table(current_table, new_table) - def get_sql(self, quote_char: str = '"', with_alias: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "{left}{comparator}{right}".format( comparator=self.comparator.value, - left=self.left.get_sql(quote_char=quote_char, **kwargs), - right=self.right.get_sql(quote_char=quote_char, **kwargs), + left=self.left.get_sql(ctx), + right=self.right.get_sql(ctx), ) - if with_alias: - return format_alias_sql(sql, self.alias, **kwargs) + if ctx.with_alias: + return format_alias_sql(sql, self.alias, ctx) return sql @@ -970,13 +904,14 @@ def replace_table( # type:ignore[return] """ self.term = self.term.replace_table(current_table, new_table) - def get_sql(self, subquery: Any = None, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: + container_ctx = ctx.copy(subquery=True) sql = "{term} {not_}IN {container}".format( - term=self.term.get_sql(**kwargs), - container=self.container.get_sql(subquery=True, **kwargs), + term=self.term.get_sql(ctx), + container=self.container.get_sql(container_ctx), not_="NOT " if self._is_negated else "", ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) @builder def negate(self) -> "Self": # type:ignore[return,override] @@ -998,7 +933,7 @@ def nodes_(self) -> Iterator[NodeT]: @property def is_aggregate(self) -> bool | None: # type:ignore[override] - return self.term.is_aggregate # type:ignore[has-type] + return self.term.is_aggregate class BetweenCriterion(RangeCriterion): @@ -1018,24 +953,24 @@ def replace_table( # type:ignore[return] """ self.term = self.term.replace_table(current_table, new_table) - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: # FIXME escape sql = "{term} BETWEEN {start} AND {end}".format( - term=self.term.get_sql(**kwargs), - start=self.start.get_sql(**kwargs), - end=self.end.get_sql(**kwargs), + term=self.term.get_sql(ctx), + start=self.start.get_sql(ctx), + end=self.end.get_sql(ctx), ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) class PeriodCriterion(RangeCriterion): - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "{term} FROM {start} TO {end}".format( - term=self.term.get_sql(**kwargs), - start=self.start.get_sql(**kwargs), - end=self.end.get_sql(**kwargs), + term=self.term.get_sql(ctx), + start=self.start.get_sql(ctx), + end=self.end.get_sql(ctx), ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) class BitwiseAndCriterion(Criterion): @@ -1065,12 +1000,12 @@ def replace_table( # type:ignore[return] """ self.term = self.term.replace_table(current_table, new_table) - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "({term} & {value})".format( - term=self.term.get_sql(**kwargs), + term=self.term.get_sql(ctx), value=self.value, ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) class NullCriterion(Criterion): @@ -1098,22 +1033,24 @@ def replace_table( # type:ignore[return] """ self.term = self.term.replace_table(current_table, new_table) - def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "{term} IS NULL".format( - term=self.term.get_sql(**kwargs), + term=self.term.get_sql(ctx), ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) class ComplexCriterion(BasicCriterion): - def get_sql(self, subcriterion: bool = False, **kwargs: Any) -> str: # type:ignore[override] + def get_sql(self, ctx: SqlContext) -> str: + left_ctx = ctx.copy(subcriterion=self.needs_brackets(self.left)) + right_ctx = ctx.copy(subcriterion=self.needs_brackets(self.right)) sql = "{left} {comparator} {right}".format( comparator=self.comparator.value, - left=self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs), - right=self.right.get_sql(subcriterion=self.needs_brackets(self.right), **kwargs), + left=self.left.get_sql(left_ctx), + right=self.right.get_sql(right_ctx), ) - if subcriterion: + if ctx.subcriterion: return "({criterion})".format(criterion=sql) return sql @@ -1223,21 +1160,21 @@ def right_needs_parens(self, curr_op, right_op) -> bool: # e.g. ... - A / B, ... - A * B return right_op in self.add_order - def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: left_op, right_op = [getattr(side, "operator", None) for side in [self.left, self.right]] arithmetic_sql = "{left}{operator}{right}".format( operator=self.operator.value, left=("({})" if self.left_needs_parens(self.operator, left_op) else "{}").format( - self.left.get_sql(**kwargs) + self.left.get_sql(ctx) ), right=("({})" if self.right_needs_parens(self.operator, right_op) else "{}").format( - self.right.get_sql(**kwargs) + self.right.get_sql(ctx) ), ) - if with_alias: - return format_alias_sql(arithmetic_sql, self.alias, **kwargs) + if ctx.with_alias: + return format_alias_sql(arithmetic_sql, self.alias, ctx) return arithmetic_sql @@ -1263,7 +1200,7 @@ def is_aggregate(self) -> bool | None: # type:ignore[override] # True if all criterions/cases are True or None. None all cases are None. Otherwise, False return resolve_is_aggregate( [criterion.is_aggregate or term.is_aggregate for criterion, term in self._cases] - + [self._else.is_aggregate if self._else else None] # type:ignore[has-type] + + [self._else.is_aggregate if self._else else None] ) @builder @@ -1295,25 +1232,25 @@ def replace_table( # type:ignore[return] @builder def else_(self, term: Any) -> "Self": - self._else = self.wrap_constant(term) # type:ignore[assignment] + self._else = self.wrap_constant(term) return self - def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: if not self._cases: raise CaseException("At least one 'when' case is required for a CASE statement.") + when_then_else_ctx = ctx.copy(with_alias=False) cases = " ".join( "WHEN {when} THEN {then}".format( - when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs) + when=criterion.get_sql(when_then_else_ctx), then=term.get_sql(when_then_else_ctx) ) for criterion, term in self._cases ) - else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" - + else_ = " ELSE {}".format(self._else.get_sql(when_then_else_ctx)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) - if with_alias: - return format_alias_sql(case_sql, self.alias, **kwargs) + if ctx.with_alias: + return format_alias_sql(case_sql, self.alias, ctx) return case_sql @@ -1327,10 +1264,10 @@ def nodes_(self) -> Iterator[NodeT]: yield self # type:ignore[misc] yield from self.term.nodes_() - def get_sql(self, **kwargs: Any) -> str: - kwargs["subcriterion"] = True - sql = "NOT {term}".format(term=self.term.get_sql(**kwargs)) - return format_alias_sql(sql, self.alias, **kwargs) + def get_sql(self, ctx: SqlContext) -> str: + not_ctx = ctx.copy(subcriterion=True) + sql = "NOT {term}".format(term=self.term.get_sql(not_ctx)) + return format_alias_sql(sql, self.alias, ctx) @ignore_copy def __getattr__(self, name: str) -> Any: @@ -1377,9 +1314,9 @@ def nodes_(self) -> Iterator[NodeT]: yield self # type:ignore[misc] yield from self.term.nodes_() - def get_sql(self, **kwargs: Any) -> str: - sql = "{term} ALL".format(term=self.term.get_sql(**kwargs)) - return format_alias_sql(sql, self.alias, **kwargs) + def get_sql(self, ctx: SqlContext) -> str: + sql = "{term} ALL".format(term=self.term.get_sql(ctx)) + return format_alias_sql(sql, self.alias, ctx) class CustomFunction: @@ -1419,7 +1356,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: def nodes_(self) -> Iterator[NodeT]: yield self # type:ignore[misc] for arg in self.args: - yield from arg.nodes_() # type:ignore[union-attr] + yield from arg.nodes_() @property def is_aggregate(self) -> bool | None: # type:ignore[override] @@ -1429,9 +1366,7 @@ def is_aggregate(self) -> bool | None: # type:ignore[override] :returns: True if the function accepts one argument and that argument is aggregate. """ - return resolve_is_aggregate( - [arg.is_aggregate for arg in self.args] # type:ignore[has-type,union-attr] - ) + return resolve_is_aggregate([arg.is_aggregate for arg in self.args]) @builder def replace_table( # type:ignore[return] @@ -1447,47 +1382,38 @@ def replace_table( # type:ignore[return] :return: A copy of the criterion with the tables replaced. """ - self.args = [ - param.replace_table(current_table, new_table) # type:ignore[misc,union-attr] - for param in self.args - ] + self.args = [param.replace_table(current_table, new_table) for param in self.args] - def get_special_params_sql(self, **kwargs: Any) -> Any: + def get_special_params_sql(self, ctx: SqlContext) -> Any: pass @staticmethod - def get_arg_sql(arg, **kwargs) -> str: - return arg.get_sql(with_alias=False, **kwargs) if hasattr(arg, "get_sql") else str(arg) + def get_arg_sql(arg, ctx: SqlContext) -> str: + arg_ctx = ctx.copy(with_alias=False) + return arg.get_sql(arg_ctx) if hasattr(arg, "get_sql") else str(arg) - def get_function_sql(self, **kwargs: Any) -> str: + def get_function_sql(self, ctx: SqlContext) -> str: # pylint: disable=E1111 - special_params_sql = self.get_special_params_sql(**kwargs) + special_params_sql = self.get_special_params_sql(ctx) return "{name}({args}{special})".format( name=self.name, - args=",".join(self.get_arg_sql(arg, **kwargs) for arg in self.args), + args=",".join(self.get_arg_sql(arg, ctx) for arg in self.args), special=(" " + special_params_sql) if special_params_sql else "", ) - def get_sql(self, **kwargs: Any) -> str: - with_alias = kwargs.pop("with_alias", False) - with_namespace = kwargs.pop("with_namespace", False) - quote_char = kwargs.pop("quote_char", None) - dialect = kwargs.pop("dialect", None) - + def get_sql(self, ctx: SqlContext) -> str: # FIXME escape - function_sql = self.get_function_sql( - with_namespace=with_namespace, quote_char=quote_char, dialect=dialect, **kwargs - ) + function_sql = self.get_function_sql(ctx) if self.schema is not None: function_sql = "{schema}.{function}".format( - schema=self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs), + schema=self.schema.get_sql(ctx), function=function_sql, ) - if with_alias: - return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs) + if ctx.with_alias: + return format_alias_sql(function_sql, self.alias, ctx) return function_sql @@ -1506,15 +1432,15 @@ def filter(self, *filters: Any) -> AnalyticFunction: # type:ignore[return] self._include_filter = True self._filters += filters - def get_filter_sql(self, **kwargs: Any) -> str: # type:ignore[return] + def get_filter_sql(self, ctx: SqlContext) -> str: # type:ignore[return] if self._include_filter: - criterions = Criterion.all(self._filters).get_sql(**kwargs) # type:ignore[attr-defined] + criterions = Criterion.all(self._filters).get_sql(ctx) # type:ignore[attr-defined] return f"WHERE {criterions}" # TODO: handle case of `not self._include_filter` - def get_function_sql(self, **kwargs: Any) -> str: - sql = super().get_function_sql(**kwargs) - filter_sql = self.get_filter_sql(**kwargs) + def get_function_sql(self, ctx: SqlContext) -> str: + sql = super().get_function_sql(ctx) + filter_sql = self.get_filter_sql(ctx) if self._include_filter: sql += " FILTER({filter_sql})".format(filter_sql=filter_sql) @@ -1544,23 +1470,22 @@ def orderby(self, *terms: Any, **kwargs: Any) -> "Self": # type:ignore[return] self._include_over = True self._orderbys += [(term, kwargs.get("order")) for term in terms] - def _orderby_field(self, field: Field, orient: Order | None, **kwargs: Any) -> str: + def _orderby_field(self, field: Field, orient: Order | None, ctx: SqlContext) -> str: if orient is None: - return field.get_sql(**kwargs) + return field.get_sql(ctx) return "{field} {orient}".format( - field=field.get_sql(**kwargs), + field=field.get_sql(ctx), orient=orient.value, ) - def get_partition_sql(self, **kwargs: Any) -> str: + def get_partition_sql(self, ctx: SqlContext) -> str: terms = [] if self._partition: terms.append( "PARTITION BY {args}".format( args=",".join( - p.get_sql(**kwargs) if hasattr(p, "get_sql") else str(p) - for p in self._partition + p.get_sql(ctx) if hasattr(p, "get_sql") else str(p) for p in self._partition ) ) ) @@ -1569,17 +1494,16 @@ def get_partition_sql(self, **kwargs: Any) -> str: terms.append( "ORDER BY {orderby}".format( orderby=",".join( - self._orderby_field(field, orient, **kwargs) - for field, orient in self._orderbys + self._orderby_field(field, orient, ctx) for field, orient in self._orderbys ) ) ) return " ".join(terms) - def get_function_sql(self, **kwargs: Any) -> str: - function_sql = super().get_function_sql(**kwargs) - partition_sql = self.get_partition_sql(**kwargs) + def get_function_sql(self, ctx: SqlContext) -> str: + function_sql = super().get_function_sql(ctx) + partition_sql = self.get_partition_sql(ctx) sql = function_sql if self._include_over: @@ -1640,8 +1564,8 @@ def get_frame_sql(self) -> str: upper=upper, ) - def get_partition_sql(self, **kwargs: Any) -> str: - partition_sql = super().get_partition_sql(**kwargs) + def get_partition_sql(self, ctx: SqlContext) -> str: + partition_sql = super().get_partition_sql(ctx) if not self.frame and not self.bound: return partition_sql @@ -1658,7 +1582,7 @@ def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: def ignore_nulls(self) -> "Self": # type:ignore[return] self._ignore_nulls = True - def get_special_params_sql(self, **kwargs: Any) -> str | None: + def get_special_params_sql(self, ctx: SqlContext) -> str | None: if self._ignore_nulls: return "IGNORE NULLS" @@ -1722,11 +1646,9 @@ def __init__( self.smallest = label def __str__(self) -> str: - return self.get_sql() - - def get_sql(self, **kwargs: Any) -> str: - dialect = cast(Dialects, self.dialect or kwargs.get("dialect")) + return self.get_sql(DEFAULT_SQL_CONTEXT) + def get_sql(self, ctx: SqlContext) -> str: if self.largest == "MICROSECOND": expr = getattr(self, "microseconds") unit = "MICROSECOND" @@ -1762,7 +1684,9 @@ def get_sql(self, **kwargs: Any) -> str: else: unit = self.largest - return self.templates.get(dialect, "INTERVAL '{expr} {unit}'").format(expr=expr, unit=unit) + return self.templates.get(ctx.dialect, "INTERVAL '{expr} {unit}'").format( + expr=expr, unit=unit + ) class Pow(Function): @@ -1790,7 +1714,7 @@ def __init__(self, name: str) -> None: super().__init__(alias=None) self.name = name - def get_sql(self, **kwargs: Any) -> str: + def get_sql(self, ctx: SqlContext) -> str: return self.name @@ -1810,10 +1734,10 @@ def __init__(self, field, zone, interval=False, alias=None) -> None: self.zone = zone self.interval = interval - def get_sql(self, **kwargs) -> str: + def get_sql(self, ctx: SqlContext) -> str: sql = "{name} AT TIME ZONE {interval}'{zone}'".format( - name=self.field.get_sql(**kwargs), + name=self.field.get_sql(ctx), interval="INTERVAL " if self.interval else "", zone=self.zone, ) - return format_alias_sql(sql, self.alias, **kwargs) + return format_alias_sql(sql, self.alias, ctx) diff --git a/pypika_tortoise/utils.py b/pypika_tortoise/utils.py index 8d03322..fc03be7 100644 --- a/pypika_tortoise/utils.py +++ b/pypika_tortoise/utils.py @@ -2,6 +2,8 @@ from typing import Any, Callable, Type, TypeVar +from .context import SqlContext + T_Retval = TypeVar("T_Retval") T_Self = TypeVar("T_Self") @@ -78,17 +80,14 @@ def format_quotes(value: Any, quote_char: str | None) -> str: def format_alias_sql( sql: str, alias: str | None, - quote_char: str | None = None, - alias_quote_char: str | None = None, - as_keyword: bool = False, - **kwargs: Any, + ctx: SqlContext, ) -> str: if alias is None: return sql return "{sql}{_as}{alias}".format( sql=sql, - _as=" AS " if as_keyword else " ", - alias=format_quotes(alias, alias_quote_char or quote_char), + _as=" AS " if ctx.as_keyword else " ", + alias=format_quotes(alias, ctx.alias_quote_char or ctx.quote_char), ) diff --git a/pyproject.toml b/pyproject.toml index 5bf26bb..718d0c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pypika-tortoise" -version = "0.4.0" +version = "0.5.0" description = "Forked from pypika and streamline just for tortoise-orm" authors = ["long2ice "] license = "Apache-2.0" @@ -41,6 +41,7 @@ target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] pretty = true python_version = "3.8" ignore_missing_imports = true +warn_unused_ignores = true [tool.ruff] line-length = 100 diff --git a/tests/test_criterions.py b/tests/test_criterions.py index a8bf1f5..2f7dadd 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -3,6 +3,7 @@ from pypika_tortoise import Criterion, EmptyCriterion, Field, Table from pypika_tortoise import functions as fn +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.queries import QueryBuilder from pypika_tortoise.terms import Mod @@ -19,7 +20,7 @@ def test__criterion_with_alias(self): self.assertEqual('"foo"="bar"', str(c1)) self.assertEqual( '"foo"="bar" "criterion"', - c1.get_sql(with_alias=True, quote_char='"', alias_quote_char='"'), + c1.get_sql(DEFAULT_SQL_CONTEXT.copy(with_alias=True)), ) def test__criterion_eq_number(self): diff --git a/tests/test_data_types.py b/tests/test_data_types.py index 6e2333f..15cde34 100644 --- a/tests/test_data_types.py +++ b/tests/test_data_types.py @@ -1,15 +1,16 @@ import unittest import uuid +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import ValueWrapper class StringTests(unittest.TestCase): def test_inline_string_concatentation(self): - self.assertEqual("'it''s'", ValueWrapper("it's").get_sql()) + self.assertEqual("'it''s'", ValueWrapper("it's").get_sql(DEFAULT_SQL_CONTEXT)) class UuidTests(unittest.TestCase): def test_uuid_string_generation(self): id = uuid.uuid4() - self.assertEqual("'{}'".format(id), ValueWrapper(id).get_sql()) + self.assertEqual("'{}'".format(id), ValueWrapper(id).get_sql(DEFAULT_SQL_CONTEXT)) diff --git a/tests/test_date_math.py b/tests/test_date_math.py index 4dd00e1..2268c09 100644 --- a/tests/test_date_math.py +++ b/tests/test_date_math.py @@ -2,7 +2,10 @@ from pypika_tortoise import Field as F from pypika_tortoise import Interval -from pypika_tortoise.enums import Dialects +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT +from pypika_tortoise.dialects.mysql import MySQLQuery +from pypika_tortoise.dialects.oracle import OracleQuery +from pypika_tortoise.dialects.postgresql import PostgreSQLQuery dt = F("dt") @@ -123,45 +126,37 @@ def test_add_value_complex_expressions(self): class DialectIntervalTests(unittest.TestCase): def test_mysql_dialect_uses_single_quotes_around_expression_in_an_interval(self): - c = Interval(days=1).get_sql(dialect=Dialects.MYSQL) + c = Interval(days=1).get_sql(MySQLQuery.SQL_CONTEXT) self.assertEqual("INTERVAL '1' DAY", c) def test_oracle_dialect_uses_single_quotes_around_expression_in_an_interval(self): - c = Interval(days=1).get_sql(dialect=Dialects.ORACLE) + c = Interval(days=1).get_sql(OracleQuery.SQL_CONTEXT) self.assertEqual("INTERVAL '1' DAY", c) - def test_vertica_dialect_uses_single_quotes_around_interval(self): - c = Interval(days=1).get_sql(dialect=Dialects.VERTICA) - self.assertEqual("INTERVAL '1 DAY'", c) - - def test_redshift_dialect_uses_single_quotes_around_interval(self): - c = Interval(days=1).get_sql(dialect=Dialects.REDSHIFT) - self.assertEqual("INTERVAL '1 DAY'", c) - def test_postgresql_dialect_uses_single_quotes_around_interval(self): - c = Interval(days=1).get_sql(dialect=Dialects.POSTGRESQL) + c = Interval(days=1).get_sql(PostgreSQLQuery.SQL_CONTEXT) self.assertEqual("INTERVAL '1 DAY'", c) class TestNegativeIntervals(unittest.TestCase): def test_day(self): - c = Interval(days=-1).get_sql() + c = Interval(days=-1).get_sql(DEFAULT_SQL_CONTEXT) self.assertEqual("INTERVAL '-1 DAY'", c) def test_week(self): - c = Interval(weeks=-1).get_sql() + c = Interval(weeks=-1).get_sql(DEFAULT_SQL_CONTEXT) self.assertEqual("INTERVAL '-1 WEEK'", c) def test_month(self): - c = Interval(months=-1).get_sql() + c = Interval(months=-1).get_sql(DEFAULT_SQL_CONTEXT) self.assertEqual("INTERVAL '-1 MONTH'", c) def test_year(self): - c = Interval(years=-1).get_sql() + c = Interval(years=-1).get_sql(DEFAULT_SQL_CONTEXT) self.assertEqual("INTERVAL '-1 YEAR'", c) def test_year_month(self): - c = Interval(years=-1, months=-4).get_sql() + c = Interval(years=-1, months=-4).get_sql(DEFAULT_SQL_CONTEXT) self.assertEqual("INTERVAL '-1-4 YEAR_MONTH'", c) diff --git a/tests/test_formats.py b/tests/test_formats.py index 3b617df..9f27f2b 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -2,6 +2,7 @@ from pypika_tortoise import Query, Tables from pypika_tortoise import functions as fn +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT class QuoteTests(unittest.TestCase): @@ -47,7 +48,7 @@ def test_replace_quote_char_in_complex_query(self): "`foo` `foo_two`,`bar` " "FROM `efg`" ") `sq1` ON `sq0`.`foo`=`sq1`.`foo_two`", - self.query.get_sql(quote_char="`"), + self.query.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char="`")), ) def test_no_quote_char_in_complex_query(self): @@ -65,5 +66,5 @@ def test_no_quote_char_in_complex_query(self): "foo foo_two,bar " "FROM efg" ") sq1 ON sq0.foo=sq1.foo_two", - self.query.get_sql(quote_char=None), + self.query.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char=None)), ) diff --git a/tests/test_functions.py b/tests/test_functions.py index cf12ff1..dd0d8a1 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -7,13 +7,15 @@ from pypika_tortoise import Schema from pypika_tortoise import Table as T from pypika_tortoise import functions as fn -from pypika_tortoise.enums import Dialects, SqlTypes +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT +from pypika_tortoise.dialects.postgresql import PostgreSQLQuery +from pypika_tortoise.enums import SqlTypes class FunctionTests(unittest.TestCase): def test_dialect_propagation(self): func = fn.Function("func", ["a"], ["b"]) - self.assertEqual("func(ARRAY['a'],ARRAY['b'])", func.get_sql(dialect=Dialects.POSTGRESQL)) + self.assertEqual("func(ARRAY['a'],ARRAY['b'])", func.get_sql(PostgreSQLQuery.SQL_CONTEXT)) def test_is_aggregate_None_for_non_aggregate_function_or_function_with_no_aggregate_functions( self, @@ -31,13 +33,15 @@ class SchemaTests(unittest.TestCase): def test_schema_no_schema_in_sql_when_none_set(self): func = fn.Function("my_proc", 1, 2, 3) - self.assertEqual("my_proc(1,2,3)", func.get_sql(quote_char='"')) + self.assertEqual("my_proc(1,2,3)", func.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char='"'))) def test_schema_included_in_function_sql(self): a = Schema("a") func = fn.Function("my_proc", 1, 2, 3, schema=a) - self.assertEqual('"a".my_proc(1,2,3)', func.get_sql(quote_char='"')) + self.assertEqual( + '"a".my_proc(1,2,3)', func.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char='"')) + ) class ArithmeticTests(unittest.TestCase): diff --git a/tests/test_joins.py b/tests/test_joins.py index 41ecfc1..4121d47 100644 --- a/tests/test_joins.py +++ b/tests/test_joins.py @@ -13,6 +13,7 @@ Tables, ) from pypika_tortoise import functions as fn +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -270,7 +271,10 @@ def test_join_using_multiple_fields(self): def test_join_using_with_quote_char(self): query = Query.from_(self.table0).join(self.table1).using("foo", "bar").select("*") - self.assertEqual("SELECT * FROM abc JOIN efg USING (foo,bar)", query.get_sql(quote_char="")) + self.assertEqual( + "SELECT * FROM abc JOIN efg USING (foo,bar)", + query.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char="")), + ) def test_join_using_without_fields_raises_exception(self): with self.assertRaises(JoinException): @@ -972,16 +976,6 @@ def test_union_as_subquery(self): str(q), ) - def test_union_with_no_quote_char(self): - abc, efg = Tables("abc", "efg") - hij = Query.from_(abc).select(abc.t).union(Query.from_(efg).select(efg.t)) - q = Query.from_(hij).select(fn.Avg(hij.t)) - - self.assertEqual( - "SELECT AVG(sq0.t) FROM ((SELECT t FROM abc) UNION (SELECT t FROM efg)) sq0", - q.get_sql(quote_char=None), - ) - class InsertQueryJoinTests(unittest.TestCase): def test_join_table_on_insert_query(self): @@ -1121,16 +1115,6 @@ def test_intersect_as_subquery(self): str(q), ) - def test_intersect_with_no_quote_char(self): - abc, efg = Tables("abc", "efg") - hij = Query.from_(abc).select(abc.t).intersect(Query.from_(efg).select(efg.t)) - q = Query.from_(hij).select(fn.Avg(hij.t)) - - self.assertEqual( - "SELECT AVG(sq0.t) FROM ((SELECT t FROM abc) INTERSECT (SELECT t FROM efg)) sq0", - q.get_sql(quote_char=None), - ) - class MinusTests(unittest.TestCase): table1, table2 = Tables("abc", "efg") @@ -1226,16 +1210,6 @@ def test_minus_as_subquery(self): str(q), ) - def test_minus_with_no_quote_char(self): - abc, efg = Tables("abc", "efg") - hij = Query.from_(abc).select(abc.t).minus(Query.from_(efg).select(efg.t)) - q = Query.from_(hij).select(fn.Avg(hij.t)) - - self.assertEqual( - "SELECT AVG(sq0.t) FROM ((SELECT t FROM abc) MINUS (SELECT t FROM efg)) sq0", - q.get_sql(quote_char=None), - ) - class ExceptOfTests(unittest.TestCase): table1, table2 = Tables("abc", "efg") @@ -1309,13 +1283,3 @@ def test_except_as_subquery(self): 'SELECT AVG("sq0"."t") FROM ((SELECT "t" FROM "abc") EXCEPT (SELECT "t" FROM "efg")) "sq0"', str(q), ) - - def test_except_with_no_quote_char(self): - abc, efg = Tables("abc", "efg") - hij = Query.from_(abc).select(abc.t).except_of(Query.from_(efg).select(efg.t)) - q = Query.from_(hij).select(fn.Avg(hij.t)) - - self.assertEqual( - "SELECT AVG(sq0.t) FROM ((SELECT t FROM abc) EXCEPT (SELECT t FROM efg)) sq0", - q.get_sql(quote_char=None), - ) diff --git a/tests/test_negation.py b/tests/test_negation.py index 9624dec..5d959c1 100644 --- a/tests/test_negation.py +++ b/tests/test_negation.py @@ -2,6 +2,7 @@ from pypika_tortoise import Tables from pypika_tortoise import functions as fn +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import ValueWrapper @@ -11,19 +12,24 @@ class NegationTests(unittest.TestCase): def test_negate_wrapped_float(self): q = -ValueWrapper(1.0) - self.assertEqual("-1.0", q.get_sql()) + self.assertEqual("-1.0", q.get_sql(DEFAULT_SQL_CONTEXT)) def test_negate_wrapped_int(self): q = -ValueWrapper(1) - self.assertEqual("-1", q.get_sql()) + self.assertEqual("-1", q.get_sql(DEFAULT_SQL_CONTEXT)) def test_negate_field(self): q = -self.table_abc.foo - self.assertEqual('-"abc"."foo"', q.get_sql(with_namespace=True, quote_char='"')) + self.assertEqual( + '-"abc"."foo"', q.get_sql(DEFAULT_SQL_CONTEXT.copy(with_namespace=True, quote_char='"')) + ) def test_negate_function(self): q = -fn.Sum(self.table_abc.foo) - self.assertEqual('-SUM("abc"."foo")', q.get_sql(with_namespace=True, quote_char='"')) + self.assertEqual( + '-SUM("abc"."foo")', + q.get_sql(DEFAULT_SQL_CONTEXT.copy(with_namespace=True, quote_char='"')), + ) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index 71a9a79..834fbf7 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -2,10 +2,12 @@ from datetime import date from pypika_tortoise import Parameter, Query, Tables, ValueWrapper +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.dialects.mssql import MSSQLQuery from pypika_tortoise.dialects.mysql import MySQLQuery +from pypika_tortoise.dialects.oracle import OracleQuery from pypika_tortoise.dialects.postgresql import PostgreSQLQuery -from pypika_tortoise.enums import Dialects +from pypika_tortoise.dialects.sqlite import SQLLiteQuery from pypika_tortoise.functions import Upper from pypika_tortoise.terms import Case, Parameterizer @@ -82,27 +84,27 @@ def test_join(self): ) def test_qmark_parameter(self): - self.assertEqual("?", Parameter("?").get_sql()) + self.assertEqual("?", Parameter("?").get_sql(DEFAULT_SQL_CONTEXT)) def test_oracle(self): - self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.ORACLE)) - self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.ORACLE)) + self.assertEqual("?", Parameter(idx=1).get_sql(OracleQuery.SQL_CONTEXT)) + self.assertEqual("?", Parameter(idx=2).get_sql(OracleQuery.SQL_CONTEXT)) def test_mssql(self): - self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.MSSQL)) - self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.MSSQL)) + self.assertEqual("?", Parameter(idx=1).get_sql(MSSQLQuery.SQL_CONTEXT)) + self.assertEqual("?", Parameter(idx=2).get_sql(MSSQLQuery.SQL_CONTEXT)) def test_mysql(self): - self.assertEqual("%s", Parameter(idx=1).get_sql(dialect=Dialects.MYSQL)) - self.assertEqual("%s", Parameter(idx=2).get_sql(dialect=Dialects.MYSQL)) + self.assertEqual("%s", Parameter(idx=1).get_sql(MySQLQuery.SQL_CONTEXT)) + self.assertEqual("%s", Parameter(idx=2).get_sql(MySQLQuery.SQL_CONTEXT)) def test_postgres(self): - self.assertEqual("$1", Parameter(idx=1).get_sql(dialect=Dialects.POSTGRESQL)) - self.assertEqual("$2", Parameter(idx=2).get_sql(dialect=Dialects.POSTGRESQL)) + self.assertEqual("$1", Parameter(idx=1).get_sql(PostgreSQLQuery.SQL_CONTEXT)) + self.assertEqual("$2", Parameter(idx=2).get_sql(PostgreSQLQuery.SQL_CONTEXT)) def test_sqlite(self): - self.assertEqual("?", Parameter(idx=1).get_sql(dialect=Dialects.SQLITE)) - self.assertEqual("?", Parameter(idx=2).get_sql(dialect=Dialects.SQLITE)) + self.assertEqual("?", Parameter(idx=1).get_sql(SQLLiteQuery.SQL_CONTEXT)) + self.assertEqual("?", Parameter(idx=2).get_sql(SQLLiteQuery.SQL_CONTEXT)) class ParameterizerTests(unittest.TestCase): @@ -111,10 +113,9 @@ class ParameterizerTests(unittest.TestCase): def test_param_insert(self): q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, "foo") - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) - self.assertEqual([1, 2.2, "foo"], parameterizer.values) + self.assertEqual([1, 2.2, "foo"], values) def test_select_join_in_mysql(self): q = ( @@ -127,13 +128,12 @@ def test_select_join_in_mysql(self): .limit(10) ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.MYSQL) + sql, values = q.get_parameterized_sql() self.assertEqual( "SELECT * FROM `abc` JOIN `efg` ON `abc`.`id`=`efg`.`abc_id` WHERE `abc`.`category`=%s AND `efg`.`date`>=%s LIMIT %s", sql, ) - self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) + self.assertEqual(["foobar", date(2024, 2, 22), 10], values) def test_select_subquery_in_postgres(self): q = ( @@ -150,13 +150,12 @@ def test_select_subquery_in_postgres(self): .limit(10) ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) + sql, values = q.get_parameterized_sql() self.assertEqual( 'SELECT * FROM "abc" WHERE "category"=$1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=$2) LIMIT $3', sql, ) - self.assertEqual(["foobar", date(2024, 2, 22), 10], parameterizer.values) + self.assertEqual(["foobar", date(2024, 2, 22), 10], values) def test_join_in_postgres(self): subquery = ( @@ -173,14 +172,13 @@ def test_join_in_postgres(self): .where(self.table_abc.bar == "bar") ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer, dialect=Dialects.POSTGRESQL) + sql, values = q.get_parameterized_sql() self.assertEqual( 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=$1)' ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=$2', sql, ) - self.assertEqual(["buz", "bar"], parameterizer.values) + self.assertEqual(["buz", "bar"], values) def test_function_parameter(self): q = ( @@ -188,20 +186,18 @@ def test_function_parameter(self): .select("*") .where(self.table_abc.category == Upper(ValueWrapper("foobar"))) ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual('SELECT * FROM "abc" WHERE "category"=UPPER(?)', sql) - self.assertEqual(["foobar"], parameterizer.values) + self.assertEqual(["foobar"], values) def test_case_when_in_select(self): q = Query.from_(self.table_abc).select( Case().when(self.table_abc.category == "foobar", 1).else_(2) ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual('SELECT CASE WHEN "category"=? THEN ? ELSE ? END FROM "abc"', sql) - self.assertEqual(["foobar", 1, 2], parameterizer.values) + self.assertEqual(["foobar", 1, 2], values) def test_case_when_in_where(self): q = ( @@ -212,31 +208,28 @@ def test_case_when_in_where(self): > Case().when(self.table_abc.category == "foobar", 1).else_(2) ) ) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual( 'SELECT * FROM "abc" WHERE "category_int">CASE WHEN "category"=? THEN ? ELSE ? END', sql, ) - self.assertEqual(["foobar", 1, 2], parameterizer.values) + self.assertEqual(["foobar", 1, 2], values) def test_limit_and_offest(self): q = Query.from_(self.table_abc).select("*").limit(10).offset(5) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual('SELECT * FROM "abc" LIMIT ? OFFSET ?', sql) - self.assertEqual([10, 5], parameterizer.values) + self.assertEqual([10, 5], values) def test_limit_and_offest_in_mssql(self): q = MSSQLQuery.from_(self.table_abc).select("*").limit(10).offset(5) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, values = q.get_parameterized_sql() self.assertEqual( 'SELECT * FROM "abc" ORDER BY (SELECT 0) OFFSET ? ROWS FETCH NEXT ? ROWS ONLY', sql ) - self.assertEqual([5, 10], parameterizer.values) + self.assertEqual([5, 10], values) def test_placeholder_factory(self): parameterizer = Parameterizer(placeholder_factory=lambda _: "%s") param = parameterizer.create_param(1) - self.assertEqual("%s", param.get_sql()) + self.assertEqual("%s", param.get_sql(PostgreSQLQuery.SQL_CONTEXT)) diff --git a/tests/test_selects.py b/tests/test_selects.py index 044e960..7a09579 100644 --- a/tests/test_selects.py +++ b/tests/test_selects.py @@ -17,6 +17,7 @@ Tables, ) from pypika_tortoise import functions as fn +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import Field, ValueWrapper __author__ = "Timothy Heys" @@ -641,7 +642,7 @@ def test_groupby__no_alias(self): self.assertEqual( 'SELECT SUM("foo"),"bar" "bar01" FROM "abc" GROUP BY "bar"', - q.get_sql(groupby_alias=False), + q.get_sql(DEFAULT_SQL_CONTEXT.copy(groupby_alias=False)), ) def test_groupby__alias_platforms(self): @@ -654,11 +655,7 @@ def test_groupby__alias_platforms(self): ]: q = query_cls.from_(self.t).select(fn.Sum(self.t.foo), bar).groupby(bar) - quote_char = ( - query_cls._builder().QUOTE_CHAR - if isinstance(query_cls._builder().QUOTE_CHAR, str) - else '"' - ) + quote_char = query_cls.SQL_CONTEXT.quote_char self.assertEqual( "SELECT " @@ -844,7 +841,7 @@ def test_orderby_no_alias(self): self.assertEqual( 'SELECT SUM("foo"),"bar" "bar01" FROM "abc" ORDER BY "bar"', - q.get_sql(orderby_alias=False), + q.get_sql(DEFAULT_SQL_CONTEXT.copy(orderby_alias=False)), ) def test_orderby_alias(self): @@ -1240,5 +1237,5 @@ def test_extraneous_quotes(self): "SELECT t1.value FROM table1 t1 " "JOIN table2 t2 ON t1.Value " "BETWEEN t2.start AND t2.end", - query.get_sql(quote_char=None), + query.get_sql(DEFAULT_SQL_CONTEXT.copy(quote_char="")), ) diff --git a/tests/test_tables.py b/tests/test_tables.py index 12acf5a..4116072 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -11,6 +11,7 @@ Table, Tables, ) +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -25,7 +26,10 @@ def test_table_sql(self): def test_table_with_alias(self): table = Table("test_table").as_("my_table") - self.assertEqual('"test_table" "my_table"', table.get_sql(with_alias=True, quote_char='"')) + self.assertEqual( + '"test_table" "my_table"', + table.get_sql(DEFAULT_SQL_CONTEXT.copy(with_alias=True, quote_char='"')), + ) def test_schema_table_attr(self): table = Schema("x_schema").test_table @@ -168,36 +172,36 @@ class TableDialectTests(unittest.TestCase): def test_table_with_default_query_cls(self): table = Table("abc") q = table.select("1") - self.assertIs(q.dialect, None) + self.assertIs(q.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_table_with_dialect_query_cls(self): table = Table("abc", query_cls=SQLLiteQuery) q = table.select("1") - self.assertIs(q.dialect, Dialects.SQLITE) + self.assertIs(q.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_table_factory_with_default_query_cls(self): table = Query.Table("abc") q = table.select("1") - self.assertIs(q.dialect, None) + self.assertIs(q.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_table_factory_with_dialect_query_cls(self): table = SQLLiteQuery.Table("abc") q = table.select("1") - self.assertIs(q.dialect, Dialects.SQLITE) + self.assertIs(q.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_make_tables_factory_with_default_query_cls(self): t1, t2 = Query.Tables("abc", "def") q1 = t1.select("1") q2 = t2.select("2") - self.assertIs(q1.dialect, None) - self.assertIs(q2.dialect, None) + self.assertIs(q1.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) + self.assertIs(q2.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_make_tables_factory_with_dialect_query_cls(self): t1, t2 = SQLLiteQuery.Tables("abc", "def") q1 = t1.select("1") q2 = t2.select("2") - self.assertIs(q1.dialect, Dialects.SQLITE) - self.assertIs(q2.dialect, Dialects.SQLITE) + self.assertIs(q1.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) + self.assertIs(q2.QUERY_CLS.SQL_CONTEXT.dialect, Dialects.SQLITE) def test_table_with_bad_query_cls(self): with self.assertRaises(TypeError): diff --git a/tests/test_terms.py b/tests/test_terms.py index 28f8314..bb73f43 100644 --- a/tests/test_terms.py +++ b/tests/test_terms.py @@ -1,6 +1,7 @@ from unittest import TestCase from pypika_tortoise import Field, Query, Table +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import AtTimezone, Parameterizer, ValueWrapper @@ -47,17 +48,21 @@ def test_passes_kwargs_to_field_get_sql(self): self.assertEqual( 'SELECT "customers"."date" AT TIME ZONE \'US/Eastern\' "alias1" ' 'FROM "customers" JOIN "accounts" ON "customers"."account_id"="accounts"."account_id"', - query.get_sql(with_namespace=True), + query.get_sql(DEFAULT_SQL_CONTEXT.copy(with_namespace=True)), ) class ValueWrapperTests(TestCase): def test_allow_parametrize(self): value = ValueWrapper("foo") - self.assertEqual("'foo'", value.get_sql()) + self.assertEqual("'foo'", value.get_sql(DEFAULT_SQL_CONTEXT)) value = ValueWrapper("foo") - self.assertEqual("?", value.get_sql(parameterizer=Parameterizer())) + self.assertEqual( + "?", value.get_sql(DEFAULT_SQL_CONTEXT.copy(parameterizer=Parameterizer())) + ) value = ValueWrapper("foo", allow_parametrize=False) - self.assertEqual("'foo'", value.get_sql(parameterizer=Parameterizer())) + self.assertEqual( + "'foo'", value.get_sql(DEFAULT_SQL_CONTEXT.copy(parameterizer=Parameterizer())) + ) diff --git a/tests/test_tuples.py b/tests/test_tuples.py index 055776b..a4144c0 100644 --- a/tests/test_tuples.py +++ b/tests/test_tuples.py @@ -2,7 +2,7 @@ from pypika_tortoise import Array, Bracket, PostgreSQLQuery, Query, Table, Tables, Tuple from pypika_tortoise.functions import Coalesce, NullIf, Sum -from pypika_tortoise.terms import Field, Parameterizer, Star +from pypika_tortoise.terms import Field, Star class TupleTests(unittest.TestCase): @@ -152,8 +152,7 @@ def test_render_alias_in_array_sql(self): def test_parametrization(self): q = Query.from_(self.table_abc).select(Star()).where(self.table_abc.f == Array(1, 2, 3)) - parameterizer = Parameterizer() - sql = q.get_sql(parameterizer=parameterizer) + sql, _ = q.get_parameterized_sql() self.assertEqual('SELECT * FROM "abc" WHERE "f"=?', sql)