Skip to content

Commit

Permalink
Merge pull request #18 from waketzheng/simple-refactor
Browse files Browse the repository at this point in the history
Simple refactor and improve type hints
  • Loading branch information
waketzheng authored Nov 30, 2024
2 parents a303518 + 74b396b commit 413d7ed
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 120 deletions.
7 changes: 2 additions & 5 deletions pypika/dialects/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def top(self, value: str | int) -> MSSQLQueryBuilder: # type:ignore[return]
try:
self._top = int(value)
except ValueError:
raise QueryException("TOP value must be an integer")
raise QueryException("TOP value must be an integer") from None

@builder
def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return]
Expand Down Expand Up @@ -76,10 +76,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str:
return super().get_sql(*args, **kwargs)

def _top_sql(self) -> str:
if self._top:
return "TOP ({}) ".format(self._top)
else:
return ""
return "TOP ({}) ".format(self._top) if self._top else ""

def _select_sql(self, **kwargs: Any) -> str:
return "SELECT {distinct}{top}{select}".format(
Expand Down
16 changes: 7 additions & 9 deletions pypika/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ def _on_conflict_sql(self, **kwargs: Any) -> str:
def get_sql(self, **kwargs: Any) -> str: # type:ignore[override]
self._set_kwargs_defaults(kwargs)
querystring = super().get_sql(**kwargs)
if querystring:
if self._update_table:
if self._orderbys:
querystring += self._orderby_sql(**kwargs)
if self._limit:
querystring += self._limit_sql()
if querystring and self._update_table:
if self._orderbys:
querystring += self._orderby_sql(**kwargs)
if self._limit:
querystring += self._limit_sql()
return querystring

def _on_conflict_action_sql(self, **kwargs: Any) -> str:
Expand All @@ -88,10 +87,9 @@ def _on_conflict_action_sql(self, **kwargs: Any) -> str:
)
else:
updates.append(
"{field}={alias_quote_char}{alias}{alias_quote_char}.{value}".format(
alias_quote_char=self.QUOTE_CHAR,
"{field}={alias}.{value}".format(
field=field.get_sql(**kwargs),
alias=self.alias,
alias=format_quotes(self.alias, self.QUOTE_CHAR),
value=field.get_sql(**kwargs),
)
)
Expand Down
13 changes: 10 additions & 3 deletions pypika/dialects/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from __future__ import annotations

import itertools
import sys
from copy import copy
from typing import Any
from typing import TYPE_CHECKING, Any

from pypika.enums import Dialects
from pypika.exceptions import QueryException
from pypika.queries import Query, QueryBuilder
from pypika.terms import ArithmeticExpression, Field, Function, Star, Term
from pypika.utils import builder

if TYPE_CHECKING:
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class PostgreSQLQuery(Query):
"""
Expand All @@ -32,8 +39,8 @@ def __init__(self, **kwargs: Any) -> None:

self._distinct_on: list[Field | Term] = []

def __copy__(self) -> PostgreSQLQueryBuilder:
newone: PostgreSQLQueryBuilder = super().__copy__() # type:ignore[assignment]
def __copy__(self) -> "Self":
newone = super().__copy__()
newone._returns = copy(self._returns)
newone._on_conflict_do_updates = copy(self._on_conflict_do_updates)
return newone
Expand Down
Loading

0 comments on commit 413d7ed

Please sign in to comment.