diff --git a/lib/dl_connector_clickhouse/dl_connector_clickhouse/formula/definitions/functions_window.py b/lib/dl_connector_clickhouse/dl_connector_clickhouse/formula/definitions/functions_window.py index 385f1c63f..10b332594 100644 --- a/lib/dl_connector_clickhouse/dl_connector_clickhouse/formula/definitions/functions_window.py +++ b/lib/dl_connector_clickhouse/dl_connector_clickhouse/formula/definitions/functions_window.py @@ -13,9 +13,12 @@ from dl_formula.definitions.base import ( FuncTranslationImplementationBase, - TranslateCallback, TranslationVariant, ) +from dl_formula.definitions.common import ( + TranslateCallback, + over, +) import dl_formula.definitions.functions_window as base from dl_formula.translation.context import TranslationCtx from dl_formula.translation.env import TranslationEnvironment @@ -80,8 +83,8 @@ def translation_rank(value: Any, *args: Any) -> SAFunction: ) order_by_part = base._order_by_from_args(*args) # type: ignore # 2024-01-30 # TODO: Argument 1 to "_order_by_from_args" has incompatible type "*Iterable[TranslationCtx | ClauseElement]"; expected "ClauseElement" [arg-type] - wf_rank = translation_rank(*args).over(partition_by=partition_by, order_by=order_by_part) - wf_total_partition_rows = sa.func.COUNT(1).over(partition_by=partition_by) # ORDER BY is unnecessary + wf_rank = over(translation_rank(*args), partition_by=partition_by, order_by=order_by_part) + wf_total_partition_rows = over(sa.func.COUNT(1), partition_by=partition_by) # ORDER BY is unnecessary result = (wf_rank - 1) / (wf_total_partition_rows - 1) return cast(ClauseElement, result) diff --git a/lib/dl_formula/dl_formula/definitions/base.py b/lib/dl_formula/dl_formula/definitions/base.py index 843e68d37..64cfda504 100644 --- a/lib/dl_formula/dl_formula/definitions/base.py +++ b/lib/dl_formula/dl_formula/definitions/base.py @@ -39,6 +39,11 @@ ArgFlagDispenser, ArgTypeMatcher, ) +from dl_formula.definitions.common import ( + TransCallResult, + TranslateCallback, + over, +) from dl_formula.definitions.flags import ContextFlags from dl_formula.definitions.scope import Scope from dl_formula.definitions.type_strategy import ( @@ -90,10 +95,6 @@ def match(self, dialect: DialectCombo) -> bool: return self.dialects == D.ANY or self.dialects & dialect == dialect -TransCallResult = Union[ClauseElement, nodes.FormulaItem] -TranslateCallback = Callable[[nodes.FormulaItem], TransCallResult] - - _TRANS_IMPL_TV = TypeVar("_TRANS_IMPL_TV", bound="FuncTranslationImplementationBase") @@ -273,7 +274,7 @@ def translate( # Note that an `Over` object cannot be simply reconstructed from its parts # as attributes of the existing `Over` instance. # RANGE_UNBOUNDED has to be replaced with None and RANGE_CURRENT with 0. - return func_part.over(partition_by=partition_by, order_by=order_by_part, range_=range_part, rows=rows_part) + return over(func_part, partition_by=partition_by, order_by=order_by_part, range_=range_part, rows=rows_part) _TRANS_VAR_TV = TypeVar("_TRANS_VAR_TV", bound="TranslationVariant") diff --git a/lib/dl_formula/dl_formula/definitions/common.py b/lib/dl_formula/dl_formula/definitions/common.py index 7a22438b3..8610d8502 100644 --- a/lib/dl_formula/dl_formula/definitions/common.py +++ b/lib/dl_formula/dl_formula/definitions/common.py @@ -1,14 +1,22 @@ from __future__ import annotations from typing import ( - TYPE_CHECKING, Callable, + Optional, + Tuple, TypeVar, + Union, + no_type_check, ) import sqlalchemy as sa from sqlalchemy.sql.elements import ( + RANGE_CURRENT, + RANGE_UNBOUNDED, Case, + ClauseElement, + ColumnElement, + Over, UnaryExpression, WithinGroup, ) @@ -18,17 +26,9 @@ exc, nodes, ) -from dl_formula.definitions.base import TransCallResult from dl_formula.shortcuts import n -if TYPE_CHECKING: - from sqlalchemy.sql.elements import ( - ClauseElement, - ColumnElement, - ) - - class _TextClauseHack(sa.sql.elements.TextClause): def __bool__(self) -> bool: # type: ignore # TODO: bug in sqlalchemy stubs? Return type "bool" of "__bool__" incompatible with return type "None" in supertype "ClauseElement" [override] # Possibly a bug in sqlalchemy 1.4: query caching fails as, normally, @@ -42,6 +42,9 @@ def raw_sql(sql_text: str) -> _TextClauseHack: return _TextClauseHack(sql_text) +TransCallResult = Union[ClauseElement, nodes.FormulaItem] +TranslateCallback = Callable[[nodes.FormulaItem], TransCallResult] + _BINARY_CHAIN_TV = TypeVar("_BINARY_CHAIN_TV", bound=TransCallResult) @@ -104,3 +107,61 @@ def __reduce__(self) -> tuple: def within_group(clause_el: ClauseElement, *order_by: ClauseElement) -> _PatchedWithinGroup: return _PatchedWithinGroup(clause_el, *order_by) + + +class _PatchedOver(Over): + """Backport for https://github.com/sqlalchemy/sqlalchemy/issues/11422""" + + @no_type_check # keeping the original typing + def _interpret_range(self, range_): + """ + Mostly copied from + https://github.com/sqlalchemy/sqlalchemy/blob/rel_1_4/lib/sqlalchemy/sql/elements.py#L4229-L4265 + except where noted + """ + if not isinstance(range_, tuple) or len(range_) != 2: + raise sa.exc.ArgumentError("2-tuple expected for range/rows") # non-local import + + if range_[0] is None: + lower = RANGE_UNBOUNDED + elif range_[0] is RANGE_UNBOUNDED or range_[0] is RANGE_CURRENT: # fixes issues#11422 + lower = range_[0] + else: + try: + lower = int(range_[0]) + except ValueError as err: + sa._util.raise_( # non-local import + sa.exc.ArgumentError("Integer or None expected for range value"), # non-local import + replace_context=err, + ) + else: + if lower == 0: + lower = RANGE_CURRENT + + if range_[1] is None: + upper = RANGE_UNBOUNDED + elif range_[1] is RANGE_UNBOUNDED or range_[1] is RANGE_CURRENT: # fixes issues#11422 + upper = range_[1] + else: + try: + upper = int(range_[1]) + except ValueError as err: + sa._util.raise_( # non-local import + sa.exc.ArgumentError("Integer or None expected for range value"), # non-local import + replace_context=err, + ) + else: + if upper == 0: + upper = RANGE_CURRENT + + return lower, upper + + +def over( + clause_el: ClauseElement, + partition_by: Optional[ClauseElement] = None, + order_by: Optional[ClauseElement] = None, + rows: Optional[Tuple[Optional[int], Optional[int]]] = None, + range_: Optional[Tuple[Optional[int], Optional[int]]] = None, +) -> _PatchedOver: + return _PatchedOver(clause_el, partition_by=partition_by, order_by=order_by, rows=rows, range_=range_)