Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Enable join using common columns #198

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions greenplumpython/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def join(
other: "DataFrame",
how: Literal["", "left", "right", "outer", "inner", "cross"] = "",
cond: Optional[Callable[["DataFrame", "DataFrame"], Expr]] = None,
on: Optional[Union[str, Iterable[str]]] = None,
on: Iterable[str] = None,
self_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"},
other_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"},
) -> "DataFrame":
Expand Down Expand Up @@ -583,18 +583,18 @@ def join(
... age_rows, column_names=["name", "age"], db=db)
>>> result = student.join(
... student,
... on="age",
... self_columns={"*"},
... other_columns={"name": "name_2"})
... on=["age"],
... self_columns={"name": "name", "age": "age_1"},
... other_columns={"name": "name_2", "age": "age_2"})
>>> result
----------------------
name | age | name_2
-------+-----+--------
alice | 18 | alice
bob | 19 | carol
bob | 19 | bob
carol | 19 | carol
carol | 19 | bob
age | name | name_2
-----+-------+--------
18 | alice | alice
19 | bob | carol
19 | bob | bob
19 | carol | carol
19 | carol | bob
----------------------
(5 rows)
"""
Expand Down Expand Up @@ -631,13 +631,80 @@ def bind(t: DataFrame, columns: Union[Dict[str, Optional[str]], Set[str]]) -> Li
)
# USING clause in SQL uses argument `on`.
sql_using_clause = f"USING ({join_column_names})" if join_column_names is not None else ""
return DataFrame(

if on is None:
return DataFrame(
f"""
SELECT {",".join(target_list)}
FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause}
""",
parents=[self, other],
)

def bind_using(
t: DataFrame,
columns: Union[Dict[str, Optional[str]], Set[str]],
on: Iterable[str],
suffix: str,
) -> List[str]:
target_list: List[str] = []
for k in columns:
col: Column = t[k]
v = columns[k] if isinstance(columns, dict) else (k + suffix) if k in on else None
target_list.append(col._serialize() + (f' AS "{v}"' if v is not None else ""))
return target_list

self_target_list = (
bind_using(self, self_columns, on, "_l")
if isinstance(self_columns, set)
else bind(self, self_columns)
)
other_target_list = (
bind_using(other_temp, other_columns, on, "_r")
if isinstance(other_columns, set)
else bind(other_temp, other_columns)
)
target_list = self_target_list + other_target_list

join_dataframe = DataFrame(
f"""
SELECT {",".join(target_list)}
FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause}
""",
parents=[self, other],
)
coalesce_target_list = []
if self_columns and other_columns:
for k in on:
s_v = self_columns[k] if isinstance(self_columns, dict) else (k + "_l")
o_v = other_columns[k] if isinstance(other_columns, dict) else (k + "_r")
coalesce_target_list.append(f"COALESCE({s_v},{o_v}) AS {k}")

join_df = DataFrame(
f"""
SELECT * {("," + ",".join(coalesce_target_list)) if coalesce_target_list != [] else ""}
FROM {join_dataframe._name}
""",
parents=[join_dataframe],
)

self_columns_set = (
self_columns
if isinstance(self_columns, set)
else set([k if k in on else v for k, v in self_columns.items()])
)
other_columns_set = (
other_columns
if isinstance(other_columns, set)
else set([k if k in on else v for k, v in other_columns.items()])
)
return DataFrame(
f"""
SELECT {",".join(sorted(self_columns_set | other_columns_set))}
FROM {join_df._name}
""",
parents=[join_df],
)

inner_join = partialmethod(join, how="INNER")
"""
Expand Down
90 changes: 56 additions & 34 deletions tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def test_join_same_column_using(db: gp.Database):
rows = [(1,), (2,), (3,)]
t1 = db.create_dataframe(rows=rows, column_names=["id"])
t2 = db.create_dataframe(rows=rows, column_names=["id"])
ret = t1.join(t2, on=["id"], self_columns={"id": "t1_id"}, other_columns={"id": "t2_id"})
assert sorted(next(iter(ret)).keys()) == sorted(["t1_id", "t2_id"])
ret = t1.join(t2, on=["id"], self_columns={"id"}, other_columns={"id"})
assert list(next(iter(ret)).keys()) == ["id"]


def test_join_same_column_names(db: gp.Database):
Expand All @@ -121,21 +121,21 @@ def test_join_on_multi_columns(db: gp.Database):
rows = [(1, 1), (2, 1), (3, 1)]
t1 = db.create_dataframe(rows=rows, column_names=["id", "n"])
t2 = db.create_dataframe(rows=rows, column_names=["id", "n"])
ret = t1.join(t2, on=["id", "n"], other_columns={})
print(ret)
ret = t1.join(t2, on=["id", "n"], self_columns={"id", "n"}, other_columns={"id", "n"})
assert sorted(next(iter(ret)).keys()) == sorted(["id", "n"])


def test_dataframe_inner_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame):
ret: gp.DataFrame = zoo_1.join(
zoo_2,
on=["animal"],
self_columns={"animal": "zoo1_animal", "id": "zoo1_id"},
other_columns={"animal": "zoo2_animal", "id": "zoo2_id"},
self_columns={"animal": "animal_l", "id": "id_zoo1"},
other_columns={"animal": "animal_r", "id": "id_zoo2"},
)
assert len(list(ret)) == 2
assert sorted(next(iter(ret)).keys()) == sorted(["animal", "id_zoo1", "id_zoo2"])
for row in ret:
assert row["zoo1_animal"] == row["zoo2_animal"]
assert row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger"
assert row["animal"] == "Lion" or row["animal"] == "Tiger"


def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame):
Expand All @@ -147,10 +147,7 @@ def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat
)
assert len(list(ret)) == 4
for row in ret:
if row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger":
assert row["zoo1_animal"] == row["zoo2_animal"]
else:
assert row["zoo2_animal"] is None
if row["animal"] != "Lion" and row["animal"] != "Tiger":
assert row["zoo2_id"] is None


Expand All @@ -163,10 +160,7 @@ def test_dataframe_right_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Da
)
assert len(list(ret)) == 4
for row in ret:
if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger":
assert row["zoo1_animal"] == row["zoo2_animal"]
else:
assert row["zoo1_animal"] is None
if row["animal"] != "Lion" and row["animal"] != "Tiger":
assert row["zoo1_id"] is None


Expand All @@ -179,17 +173,42 @@ def test_dataframe_full_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat
)
assert len(list(ret)) == 6
for row in ret:
if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger":
assert row["zoo1_animal"] == row["zoo2_animal"]
else:
assert (row["zoo1_animal"] is None and row["zoo2_animal"] is not None) or (
row["zoo1_animal"] is not None and row["zoo2_animal"] is None
)
if row["animal"] != "Lion" and row["animal"] != "Tiger":
assert (row["zoo1_id"] is None and row["zoo2_id"] is not None) or (
row["zoo1_id"] is not None and row["zoo2_id"] is None
)


def test_dataframe_full_join_with_empty(db: gp.Database):
# fmt: off
rows1 = [(1, 100,), (2, 200,), (3, 300,), (4, 400,)]
rows2 = [(3, 300, 3000,), (4, 400, 4000,), (5, 500, 5000,), (6, 600, 6000)]
# fmt: on
l_df = db.create_dataframe(rows=rows1, column_names=["a", "b"])
r_df = db.create_dataframe(rows=rows2, column_names=["a", "b", "c"])
ret = l_df.full_join(
r_df,
self_columns={"a", "b"},
other_columns={"a", "b", "c"},
on=["a", "b"],
).order_by("a")[:]
assert len(list(ret)) == 6
expected = (
"----------------\n"
" a | b | c \n"
"---+-----+------\n"
" 1 | 100 | \n"
" 2 | 200 | \n"
" 3 | 300 | 3000 \n"
" 4 | 400 | 4000 \n"
" 5 | 500 | 5000 \n"
" 6 | 600 | 6000 \n"
"----------------\n"
"(6 rows)\n"
)
assert str(ret) == expected


def test_join_natural(db: gp.Database):
# fmt: off
rows1 = [("Smart Phone", 1,), ("Laptop", 2,), ("DataFramet", 3,)]
Expand All @@ -202,8 +221,8 @@ def test_join_natural(db: gp.Database):
ret = categories.join(
products,
on=["category_id"],
self_columns={"category_name", "category_id"},
other_columns={"product_name"},
self_columns={"category_id", "category_name"},
other_columns={"category_id", "product_name"},
)
assert len(list(ret)) == 6
assert sorted(next(iter(ret)).keys()) == sorted(
Expand Down Expand Up @@ -246,8 +265,6 @@ def test_dataframe_self_join(db: gp.Database, zoo_1: gp.DataFrame):
other_columns={"animal": "zoo2_animal", "id": "zoo2_id"},
)
assert len(list(ret)) == 4
for row in ret:
assert row["zoo1_animal"] == row["zoo2_animal"]


def test_dataframe_self_join_cond(db: gp.Database, zoo_1: gp.DataFrame):
Expand All @@ -271,20 +288,17 @@ def test_dataframe_join_save(db: gp.Database, zoo_1: gp.DataFrame):
)
t_join.save_as(
"dataframe_join",
column_names=["zoo1_animal", "zoo1_id", "zoo2_animal", "zoo2_id"],
column_names=["animal", "zoo1_id", "zoo2_id"],
temp=True,
)
t_join_reload = gp.DataFrame.from_table("dataframe_join", db=db)
assert sorted(next(iter(t_join_reload)).keys()) == sorted(
[
"zoo1_animal",
"animal",
"zoo1_id",
"zoo2_animal",
"zoo2_id",
]
)
for row in t_join_reload:
assert row["zoo1_animal"] == row["zoo2_animal"]


def test_dataframe_join_ine(db: gp.Database):
Expand All @@ -307,11 +321,19 @@ def test_dataframe_multiple_self_join(db: gp.Database, zoo_1: gp.DataFrame):
)
ret = t_join.join(
zoo_1,
cond=lambda s, o: s["zoo1_animal"] == o["animal"],
on=["animal"],
self_columns={"animal", "zoo1_id", "zoo2_id"},
other_columns={"animal", "id"},
)
assert len(list(ret)) == 4
for row in ret:
assert row["zoo2_animal"] == row["animal"]
assert sorted(next(iter(ret)).keys()) == sorted(
[
"animal",
"id",
"zoo1_id",
"zoo2_id",
]
)


# This test case is to guarantee that the CTEs are generated in the reversed
Expand Down
3 changes: 2 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_schema_self_join_on(db: gp.Database, t: gp.DataFrame):
ret: gp.DataFrame = t.join(
t,
on=["id"],
other_columns={"id": "id_1"},
self_columns={"id"},
other_columns={"id"},
)
assert len(list(ret)) == 10

Expand Down