Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Commit

Permalink
Modify testcase for func in where
Browse files Browse the repository at this point in the history
  • Loading branch information
ruxuez committed Aug 4, 2023
1 parent 81b0948 commit f48dbaf
Showing 1 changed file with 4 additions and 27 deletions.
31 changes: 4 additions & 27 deletions tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,30 +829,7 @@ def test_func_in_binary_expr(db: gp.Database):
assert row["val"] == (1 + 2) + (1 + 1)


def test_func_after_where(db: gp.Database):
# fmt: off
rows = [(i, i,) for i in range(0, 10)]
# fmt: on
df = db.create_dataframe(rows=rows, column_names=["a", "b"])
result = df.where(lambda t: t["a"] < 5).assign(val=lambda t: add_two(t["a"]) + add_one(t["b"]))
for i, row in enumerate(result):
assert row["val"] == (i + 2) + (i + 1)


def test_func_after_select(db: gp.Database):
# fmt: off
rows = [(i, i,) for i in range(0, 10)]
# fmt: on
df = db.create_dataframe(rows=rows, column_names=["a", "b"])
result = df[lambda t: t["a"] < 5].assign(val=lambda t: add_two(add_one(t["a"]) + t["b"]))
for i, row in enumerate(result):
assert row["val"] == (i + 1) + i + 2


def test_operator_after_select(db: gp.Database):
rows = [(i,) for i in range(0, 10)]
exponential = gp.operator("^")
df = db.create_dataframe(rows=rows, column_names=["a"])
result = df.assign(val=lambda t: exponential(add_one(t["a"]), 2))
for i, row in enumerate(result):
assert row["val"] == (i + 1) * (i + 1)
def test_func_in_where(db: gp.Database):
df = db.create_dataframe(columns={"a": [1]})
result = df.where(lambda t: add_two(t["a"]) < 5)
assert len(list(result)) == 1

0 comments on commit f48dbaf

Please sign in to comment.