Skip to content

Commit

Permalink
BI-5570: fix .over() serialization (#476) (#501)
Browse files Browse the repository at this point in the history
Failing tests are unrelated to changes in PR
  • Loading branch information
robot-datalens-back authored Jun 24, 2024
1 parent 33ae833 commit 5b6ca32
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

from dl_formula.definitions.base import (
FuncTranslationImplementationBase,
TranslateCallback,
TranslationVariant,
)
from dl_formula.definitions.common import (
TranslateCallback,
over,
)
import dl_formula.definitions.functions_window as base
from dl_formula.translation.context import TranslationCtx
from dl_formula.translation.env import TranslationEnvironment
Expand Down Expand Up @@ -80,8 +83,8 @@ def translation_rank(value: Any, *args: Any) -> SAFunction:
)
order_by_part = base._order_by_from_args(*args) # type: ignore # 2024-01-30 # TODO: Argument 1 to "_order_by_from_args" has incompatible type "*Iterable[TranslationCtx | ClauseElement]"; expected "ClauseElement" [arg-type]

wf_rank = translation_rank(*args).over(partition_by=partition_by, order_by=order_by_part)
wf_total_partition_rows = sa.func.COUNT(1).over(partition_by=partition_by) # ORDER BY is unnecessary
wf_rank = over(translation_rank(*args), partition_by=partition_by, order_by=order_by_part)
wf_total_partition_rows = over(sa.func.COUNT(1), partition_by=partition_by) # ORDER BY is unnecessary

result = (wf_rank - 1) / (wf_total_partition_rows - 1)
return cast(ClauseElement, result)
Expand Down
11 changes: 6 additions & 5 deletions lib/dl_formula/dl_formula/definitions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
ArgFlagDispenser,
ArgTypeMatcher,
)
from dl_formula.definitions.common import (
TransCallResult,
TranslateCallback,
over,
)
from dl_formula.definitions.flags import ContextFlags
from dl_formula.definitions.scope import Scope
from dl_formula.definitions.type_strategy import (
Expand Down Expand Up @@ -90,10 +95,6 @@ def match(self, dialect: DialectCombo) -> bool:
return self.dialects == D.ANY or self.dialects & dialect == dialect


TransCallResult = Union[ClauseElement, nodes.FormulaItem]
TranslateCallback = Callable[[nodes.FormulaItem], TransCallResult]


_TRANS_IMPL_TV = TypeVar("_TRANS_IMPL_TV", bound="FuncTranslationImplementationBase")


Expand Down Expand Up @@ -273,7 +274,7 @@ def translate(
# Note that an `Over` object cannot be simply reconstructed from its parts
# as attributes of the existing `Over` instance.
# RANGE_UNBOUNDED has to be replaced with None and RANGE_CURRENT with 0.
return func_part.over(partition_by=partition_by, order_by=order_by_part, range_=range_part, rows=rows_part)
return over(func_part, partition_by=partition_by, order_by=order_by_part, range_=range_part, rows=rows_part)


_TRANS_VAR_TV = TypeVar("_TRANS_VAR_TV", bound="TranslationVariant")
Expand Down
79 changes: 70 additions & 9 deletions lib/dl_formula/dl_formula/definitions/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from __future__ import annotations

from typing import (
TYPE_CHECKING,
Callable,
Optional,
Tuple,
TypeVar,
Union,
no_type_check,
)

import sqlalchemy as sa
from sqlalchemy.sql.elements import (
RANGE_CURRENT,
RANGE_UNBOUNDED,
Case,
ClauseElement,
ColumnElement,
Over,
UnaryExpression,
WithinGroup,
)
Expand All @@ -18,17 +26,9 @@
exc,
nodes,
)
from dl_formula.definitions.base import TransCallResult
from dl_formula.shortcuts import n


if TYPE_CHECKING:
from sqlalchemy.sql.elements import (
ClauseElement,
ColumnElement,
)


class _TextClauseHack(sa.sql.elements.TextClause):
def __bool__(self) -> bool: # type: ignore # TODO: bug in sqlalchemy stubs? Return type "bool" of "__bool__" incompatible with return type "None" in supertype "ClauseElement" [override]
# Possibly a bug in sqlalchemy 1.4: query caching fails as, normally,
Expand All @@ -42,6 +42,9 @@ def raw_sql(sql_text: str) -> _TextClauseHack:
return _TextClauseHack(sql_text)


TransCallResult = Union[ClauseElement, nodes.FormulaItem]
TranslateCallback = Callable[[nodes.FormulaItem], TransCallResult]

_BINARY_CHAIN_TV = TypeVar("_BINARY_CHAIN_TV", bound=TransCallResult)


Expand Down Expand Up @@ -104,3 +107,61 @@ def __reduce__(self) -> tuple:

def within_group(clause_el: ClauseElement, *order_by: ClauseElement) -> _PatchedWithinGroup:
return _PatchedWithinGroup(clause_el, *order_by)


class _PatchedOver(Over):
"""Backport for https://github.com/sqlalchemy/sqlalchemy/issues/11422"""

@no_type_check # keeping the original typing
def _interpret_range(self, range_):
"""
Mostly copied from
https://github.com/sqlalchemy/sqlalchemy/blob/rel_1_4/lib/sqlalchemy/sql/elements.py#L4229-L4265
except where noted
"""
if not isinstance(range_, tuple) or len(range_) != 2:
raise sa.exc.ArgumentError("2-tuple expected for range/rows") # non-local import

if range_[0] is None:
lower = RANGE_UNBOUNDED
elif range_[0] is RANGE_UNBOUNDED or range_[0] is RANGE_CURRENT: # fixes issues#11422
lower = range_[0]
else:
try:
lower = int(range_[0])
except ValueError as err:
sa._util.raise_( # non-local import
sa.exc.ArgumentError("Integer or None expected for range value"), # non-local import
replace_context=err,
)
else:
if lower == 0:
lower = RANGE_CURRENT

if range_[1] is None:
upper = RANGE_UNBOUNDED
elif range_[1] is RANGE_UNBOUNDED or range_[1] is RANGE_CURRENT: # fixes issues#11422
upper = range_[1]
else:
try:
upper = int(range_[1])
except ValueError as err:
sa._util.raise_( # non-local import
sa.exc.ArgumentError("Integer or None expected for range value"), # non-local import
replace_context=err,
)
else:
if upper == 0:
upper = RANGE_CURRENT

return lower, upper


def over(
clause_el: ClauseElement,
partition_by: Optional[ClauseElement] = None,
order_by: Optional[ClauseElement] = None,
rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
) -> _PatchedOver:
return _PatchedOver(clause_el, partition_by=partition_by, order_by=order_by, rows=rows, range_=range_)

0 comments on commit 5b6ca32

Please sign in to comment.