Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre-commit CI to PyDough #252

Merged
merged 7 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
ci:
autoupdate_schedule: monthly

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.7
Expand Down
2 changes: 1 addition & 1 deletion pydough/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
"explain",
"explain_structure",
"explain_term",
"get_logger",
"init_pydough_context",
"parse_json_metadata_from_file",
"to_df",
"to_sql",
"get_logger"
]

from .configs import PyDoughSession
Expand Down
4 changes: 1 addition & 3 deletions pydough/logger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Module of PyDough dealing with logging across the library
"""

__all__ = [
"get_logger"
]
__all__ = ["get_logger"]

from .logger import get_logger
8 changes: 5 additions & 3 deletions pydough/sqlglot/execute_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def convert_relation_to_sql(
relational: RelationalRoot,
dialect: SQLGlotDialect,
bindings: SqlGlotTransformBindings,
pretty_print_sql: bool = False
pretty_print_sql: bool = False,
) -> str:
"""
Convert the given relational tree to a SQL string using the given dialect.
Expand All @@ -44,7 +44,7 @@ def convert_relation_to_sql(
glot_expr: SQLGlotExpression = SQLGlotRelationalVisitor(
dialect, bindings
).relational_to_sqlglot(relational)
return glot_expr.sql(dialect,pretty=pretty_print_sql)
return glot_expr.sql(dialect, pretty=pretty_print_sql)


def convert_dialect_to_sqlglot(dialect: DatabaseDialect) -> SQLGlotDialect:
Expand Down Expand Up @@ -91,7 +91,9 @@ def execute_df(
pretty_print_sql: bool = False
if display_sql:
pretty_print_sql = True
sql: str = convert_relation_to_sql(relational, sqlglot_dialect, bindings,pretty_print_sql)
sql: str = convert_relation_to_sql(
relational, sqlglot_dialect, bindings, pretty_print_sql
)
if display_sql:
pyd_logger = get_logger(__name__)
pyd_logger.info(f"SQL query:\n {sql}")
Expand Down
6 changes: 3 additions & 3 deletions pydough/sqlglot/sqlglot_relational_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _merge_selects(
new_columns: list[SQLGlotExpression],
orig_select: Select,
deps: set[Identifier],
sort: bool = True
sort: bool = True,
) -> Select:
"""
Attempt to merge a new select statement with an existing one.
Expand All @@ -195,7 +195,7 @@ def _merge_selects(
new_columns, orig_select.expressions, deps
)
if sort:
old_exprs = sorted(old_exprs,key=repr)
old_exprs = sorted(old_exprs, key=repr)
orig_select.set("expressions", old_exprs)
if new_exprs is None:
return orig_select
Expand Down Expand Up @@ -289,7 +289,7 @@ def _build_subquery(
Select: A select statement representing the subquery.
"""
if sort:
column_exprs = sorted(column_exprs,key=repr)
column_exprs = sorted(column_exprs, key=repr)
return (
Select().select(*column_exprs).from_(Subquery(this=input_expr, alias=alias))
)
Expand Down
38 changes: 20 additions & 18 deletions pydough/sqlglot/transform_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ def convert_ndistinct(
this=sqlglot_expressions.Distinct(expressions=[column])
)


def create_convert_time_unit_function(unit: str):
"""
Creates a function that extracts a specific time unit
Expand All @@ -517,6 +518,7 @@ def create_convert_time_unit_function(unit: str):
A function that can convert operands into a SQLGlot expression matching
the functionality of `EXTRACT(unit FROM expression)`.
"""

def convert_time_unit(
raw_args: Sequence[RelationalExpression] | None,
sql_glot_args: Sequence[SQLGlotExpression],
Expand All @@ -537,35 +539,35 @@ def convert_time_unit(
from the first operand.
"""
return sqlglot_expressions.Extract(
this=sqlglot_expressions.Var(this=unit),
expression=sql_glot_args[0]
this=sqlglot_expressions.Var(this=unit), expression=sql_glot_args[0]
)

return convert_time_unit


def convert_sqrt(
raw_args: Sequence[RelationalExpression] | None,
sql_glot_args: Sequence[SQLGlotExpression],
) -> SQLGlotExpression:
raw_args: Sequence[RelationalExpression] | None,
sql_glot_args: Sequence[SQLGlotExpression],
) -> SQLGlotExpression:
"""
Support for getting the square root of the operand.
Support for getting the square root of the operand.

Args:
`raw_args`: The operands passed to the function before they were converted to
SQLGlot expressions. (Not actively used in this implementation.)
`sql_glot_args`: The operands passed to the function after they were converted
to SQLGlot expressions.
Args:
`raw_args`: The operands passed to the function before they were converted to
SQLGlot expressions. (Not actively used in this implementation.)
`sql_glot_args`: The operands passed to the function after they were converted
to SQLGlot expressions.

Returns:
The SQLGlot expression matching the functionality of
`POWER(x,0.5)`,i.e the square root.
Returns:
The SQLGlot expression matching the functionality of
`POWER(x,0.5)`,i.e the square root.
"""

return sqlglot_expressions.Pow(
this=sql_glot_args[0],
expression=sqlglot_expressions.Literal.number(0.5)
this=sql_glot_args[0], expression=sqlglot_expressions.Literal.number(0.5)
)


class SqlGlotTransformBindings:
"""
Binding infrastructure used to associate PyDough operators with a procedure
Expand Down Expand Up @@ -754,8 +756,8 @@ def add_builtin_bindings(self) -> None:
self.bind_binop(pydop.NEQ, sqlglot_expressions.NEQ)
self.bind_binop(pydop.BAN, sqlglot_expressions.And)
self.bind_binop(pydop.BOR, sqlglot_expressions.Or)
self.bind_binop(pydop.POW,sqlglot_expressions.Pow)
self.bind_binop(pydop.POWER,sqlglot_expressions.Pow)
self.bind_binop(pydop.POW, sqlglot_expressions.Pow)
self.bind_binop(pydop.POWER, sqlglot_expressions.Pow)
self.bindings[pydop.SQRT] = convert_sqrt

# Unary operators
Expand Down
20 changes: 12 additions & 8 deletions tests/simple_pydough_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,18 @@ def hour_minute_day():
transaction timestamps for specific ticker symbols ("AAPL","GOOGL","NFLX"),
ordered by transaction ID in ascending order.
"""
return Transactions(
transaction_id, HOUR(date_time), MINUTE(date_time), SECOND(date_time)
).WHERE(
ISIN(ticker.symbol,("AAPL", "GOOGL", "NFLX"))
).ORDER_BY(
transaction_id.ASC()
return (
Transactions(
transaction_id, HOUR(date_time), MINUTE(date_time), SECOND(date_time)
)
.WHERE(ISIN(ticker.symbol, ("AAPL", "GOOGL", "NFLX")))
.ORDER_BY(transaction_id.ASC())
)


def exponentiation():
return DailyPrices(low_square = low ** 2, low_sqrt = SQRT(low),
low_cbrt = POWER(low, 1/3), ).TOP_K(10, by=low_square.ASC())
return DailyPrices(
low_square=low**2,
low_sqrt=SQRT(low),
low_cbrt=POWER(low, 1 / 3),
).TOP_K(10, by=low_square.ASC())
71 changes: 56 additions & 15 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,15 +754,26 @@ def test_pipeline_e2e_errors(
lambda: pd.DataFrame(
{
"transaction_id": [
"TX001", "TX005", "TX011", "TX015", "TX021", "TX025",
"TX031", "TX033", "TX035", "TX044", "TX045", "TX049",
"TX051", "TX055"
"TX001",
"TX005",
"TX011",
"TX015",
"TX021",
"TX025",
"TX031",
"TX033",
"TX035",
"TX044",
"TX045",
"TX049",
"TX051",
"TX055",
],
"_expr0": [9, 12, 9, 12, 9, 12, 0, 0, 0, 10, 10, 16, 0, 0],
"_expr1": [30, 30, 30, 30, 30, 30, 0, 0, 0, 0, 30, 0, 0, 0],
"_expr2": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
}
)
),
),
id="broker_basic1",
),
Expand All @@ -772,22 +783,52 @@ def test_pipeline_e2e_errors(
"Broker",
lambda: pd.DataFrame(
{
"low_square" : [6642.2500, 6740.4100, 6839.2900, 6938.8900, 7039.2100,
7140.2500, 7242.0100, 16576.5625, 16900.0000, 17292.2500],
"low_sqrt" : [9.027735, 9.060905, 9.093954, 9.126883, 9.159694,
9.192388, 9.224966, 11.346806, 11.401754, 11.467345],
"low_cbrt" : [4.335633, 4.346247, 4.356809, 4.367320, 4.377781,
4.388191, 4.398553, 5.049508, 5.065797, 5.085206]
"low_square": [
6642.2500,
6740.4100,
6839.2900,
6938.8900,
7039.2100,
7140.2500,
7242.0100,
16576.5625,
16900.0000,
17292.2500,
],
"low_sqrt": [
9.027735,
9.060905,
9.093954,
9.126883,
9.159694,
9.192388,
9.224966,
11.346806,
11.401754,
11.467345,
],
"low_cbrt": [
4.335633,
4.346247,
4.356809,
4.367320,
4.377781,
4.388191,
4.398553,
5.049508,
5.065797,
5.085206,
],
}
)
),
),
id="exponentiation",
),
],
],
)
def custom_defog_test_data(
request,
) -> tuple[Callable[[], UnqualifiedNode],str,pd.DataFrame]:
) -> tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame]:
"""
Test data for test_defog_e2e. Returns a tuple of the following
arguments:
Expand All @@ -801,7 +842,7 @@ def custom_defog_test_data(

@pytest.mark.execute
def test_defog_e2e_with_custom_data(
custom_defog_test_data: tuple[Callable[[], UnqualifiedNode],str,pd.DataFrame],
custom_defog_test_data: tuple[Callable[[], UnqualifiedNode], str, pd.DataFrame],
defog_graphs: graph_fetcher,
sqlite_defog_connection: DatabaseContext,
):
Expand All @@ -810,7 +851,7 @@ def test_defog_e2e_with_custom_data(
comparing against the result of running the reference SQL query text on the
same database connector.
"""
unqualified_impl, graph_name ,answer_impl = custom_defog_test_data
unqualified_impl, graph_name, answer_impl = custom_defog_test_data
graph: GraphMetadata = defog_graphs(graph_name)
root: UnqualifiedNode = init_pydough_context(graph)(unqualified_impl)()
result: pd.DataFrame = to_df(root, metadata=graph, database=sqlite_defog_connection)
Expand Down
1 change: 1 addition & 0 deletions tests/test_pydough_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def test_pydough_to_sql(
expected_sql = expected_sql.strip()
assert actual_sql == expected_sql


@pytest.mark.parametrize(
"pydough_code,expected_sql,graph_name",
[
Expand Down