diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 918576b3..3a4cbea0 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -539,7 +539,7 @@ def join( other: "DataFrame", how: Literal["", "left", "right", "outer", "inner", "cross"] = "", cond: Optional[Callable[["DataFrame", "DataFrame"], Expr]] = None, - on: Optional[Union[str, Iterable[str]]] = None, + on: Iterable[str] = None, self_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"}, other_columns: Union[Dict[str, Optional[str]], Set[str]] = {"*"}, ) -> "DataFrame": @@ -583,18 +583,18 @@ def join( ... age_rows, column_names=["name", "age"], db=db) >>> result = student.join( ... student, - ... on="age", - ... self_columns={"*"}, - ... other_columns={"name": "name_2"}) + ... on=["age"], + ... self_columns={"name": "name", "age": "age_1"}, + ... other_columns={"name": "name_2", "age": "age_2"}) >>> result ---------------------- - name | age | name_2 - -------+-----+-------- - alice | 18 | alice - bob | 19 | carol - bob | 19 | bob - carol | 19 | carol - carol | 19 | bob + age | name | name_2 + -----+-------+-------- + 18 | alice | alice + 19 | bob | carol + 19 | bob | bob + 19 | carol | carol + 19 | carol | bob ---------------------- (5 rows) """ @@ -631,13 +631,80 @@ def bind(t: DataFrame, columns: Union[Dict[str, Optional[str]], Set[str]]) -> Li ) # USING clause in SQL uses argument `on`. sql_using_clause = f"USING ({join_column_names})" if join_column_names is not None else "" - return DataFrame( + + if on is None: + return DataFrame( + f""" + SELECT {",".join(target_list)} + FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause} + """, + parents=[self, other], + ) + + def bind_using( + t: DataFrame, + columns: Union[Dict[str, Optional[str]], Set[str]], + on: Iterable[str], + suffix: str, + ) -> List[str]: + target_list: List[str] = [] + for k in columns: + col: Column = t[k] + v = columns[k] if isinstance(columns, dict) else (k + suffix) if k in on else None + target_list.append(col._serialize() + (f' AS "{v}"' if v is not None else "")) + return target_list + + self_target_list = ( + bind_using(self, self_columns, on, "_l") + if isinstance(self_columns, set) + else bind(self, self_columns) + ) + other_target_list = ( + bind_using(other_temp, other_columns, on, "_r") + if isinstance(other_columns, set) + else bind(other_temp, other_columns) + ) + target_list = self_target_list + other_target_list + + join_dataframe = DataFrame( f""" SELECT {",".join(target_list)} FROM {self._name} {how} JOIN {other_clause} {sql_on_clause} {sql_using_clause} """, parents=[self, other], ) + coalesce_target_list = [] + if self_columns and other_columns: + for k in on: + s_v = self_columns[k] if isinstance(self_columns, dict) else (k + "_l") + o_v = other_columns[k] if isinstance(other_columns, dict) else (k + "_r") + coalesce_target_list.append(f"COALESCE({s_v},{o_v}) AS {k}") + + join_df = DataFrame( + f""" + SELECT * {("," + ",".join(coalesce_target_list)) if coalesce_target_list != [] else ""} + FROM {join_dataframe._name} + """, + parents=[join_dataframe], + ) + + self_columns_set = ( + self_columns + if isinstance(self_columns, set) + else set([k if k in on else v for k, v in self_columns.items()]) + ) + other_columns_set = ( + other_columns + if isinstance(other_columns, set) + else set([k if k in on else v for k, v in other_columns.items()]) + ) + return DataFrame( + f""" + SELECT {",".join(sorted(self_columns_set | other_columns_set))} + FROM {join_df._name} + """, + parents=[join_df], + ) inner_join = partialmethod(join, how="INNER") """ diff --git a/tests/test_join.py b/tests/test_join.py index f8ab5a50..7b31641e 100644 --- a/tests/test_join.py +++ b/tests/test_join.py @@ -103,8 +103,8 @@ def test_join_same_column_using(db: gp.Database): rows = [(1,), (2,), (3,)] t1 = db.create_dataframe(rows=rows, column_names=["id"]) t2 = db.create_dataframe(rows=rows, column_names=["id"]) - ret = t1.join(t2, on=["id"], self_columns={"id": "t1_id"}, other_columns={"id": "t2_id"}) - assert sorted(next(iter(ret)).keys()) == sorted(["t1_id", "t2_id"]) + ret = t1.join(t2, on=["id"], self_columns={"id"}, other_columns={"id"}) + assert list(next(iter(ret)).keys()) == ["id"] def test_join_same_column_names(db: gp.Database): @@ -121,21 +121,21 @@ def test_join_on_multi_columns(db: gp.Database): rows = [(1, 1), (2, 1), (3, 1)] t1 = db.create_dataframe(rows=rows, column_names=["id", "n"]) t2 = db.create_dataframe(rows=rows, column_names=["id", "n"]) - ret = t1.join(t2, on=["id", "n"], other_columns={}) - print(ret) + ret = t1.join(t2, on=["id", "n"], self_columns={"id", "n"}, other_columns={"id", "n"}) + assert sorted(next(iter(ret)).keys()) == sorted(["id", "n"]) def test_dataframe_inner_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame): ret: gp.DataFrame = zoo_1.join( zoo_2, on=["animal"], - self_columns={"animal": "zoo1_animal", "id": "zoo1_id"}, - other_columns={"animal": "zoo2_animal", "id": "zoo2_id"}, + self_columns={"animal": "animal_l", "id": "id_zoo1"}, + other_columns={"animal": "animal_r", "id": "id_zoo2"}, ) assert len(list(ret)) == 2 + assert sorted(next(iter(ret)).keys()) == sorted(["animal", "id_zoo1", "id_zoo2"]) for row in ret: - assert row["zoo1_animal"] == row["zoo2_animal"] - assert row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger" + assert row["animal"] == "Lion" or row["animal"] == "Tiger" def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.DataFrame): @@ -147,10 +147,7 @@ def test_dataframe_left_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat ) assert len(list(ret)) == 4 for row in ret: - if row["zoo1_animal"] == "Lion" or row["zoo1_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert row["zoo2_animal"] is None + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert row["zoo2_id"] is None @@ -163,10 +160,7 @@ def test_dataframe_right_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Da ) assert len(list(ret)) == 4 for row in ret: - if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert row["zoo1_animal"] is None + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert row["zoo1_id"] is None @@ -179,17 +173,42 @@ def test_dataframe_full_join(db: gp.Database, zoo_1: gp.DataFrame, zoo_2: gp.Dat ) assert len(list(ret)) == 6 for row in ret: - if row["zoo2_animal"] == "Lion" or row["zoo2_animal"] == "Tiger": - assert row["zoo1_animal"] == row["zoo2_animal"] - else: - assert (row["zoo1_animal"] is None and row["zoo2_animal"] is not None) or ( - row["zoo1_animal"] is not None and row["zoo2_animal"] is None - ) + if row["animal"] != "Lion" and row["animal"] != "Tiger": assert (row["zoo1_id"] is None and row["zoo2_id"] is not None) or ( row["zoo1_id"] is not None and row["zoo2_id"] is None ) +def test_dataframe_full_join_with_empty(db: gp.Database): + # fmt: off + rows1 = [(1, 100,), (2, 200,), (3, 300,), (4, 400,)] + rows2 = [(3, 300, 3000,), (4, 400, 4000,), (5, 500, 5000,), (6, 600, 6000)] + # fmt: on + l_df = db.create_dataframe(rows=rows1, column_names=["a", "b"]) + r_df = db.create_dataframe(rows=rows2, column_names=["a", "b", "c"]) + ret = l_df.full_join( + r_df, + self_columns={"a", "b"}, + other_columns={"a", "b", "c"}, + on=["a", "b"], + ).order_by("a")[:] + assert len(list(ret)) == 6 + expected = ( + "----------------\n" + " a | b | c \n" + "---+-----+------\n" + " 1 | 100 | \n" + " 2 | 200 | \n" + " 3 | 300 | 3000 \n" + " 4 | 400 | 4000 \n" + " 5 | 500 | 5000 \n" + " 6 | 600 | 6000 \n" + "----------------\n" + "(6 rows)\n" + ) + assert str(ret) == expected + + def test_join_natural(db: gp.Database): # fmt: off rows1 = [("Smart Phone", 1,), ("Laptop", 2,), ("DataFramet", 3,)] @@ -202,8 +221,8 @@ def test_join_natural(db: gp.Database): ret = categories.join( products, on=["category_id"], - self_columns={"category_name", "category_id"}, - other_columns={"product_name"}, + self_columns={"category_id", "category_name"}, + other_columns={"category_id", "product_name"}, ) assert len(list(ret)) == 6 assert sorted(next(iter(ret)).keys()) == sorted( @@ -246,8 +265,6 @@ def test_dataframe_self_join(db: gp.Database, zoo_1: gp.DataFrame): other_columns={"animal": "zoo2_animal", "id": "zoo2_id"}, ) assert len(list(ret)) == 4 - for row in ret: - assert row["zoo1_animal"] == row["zoo2_animal"] def test_dataframe_self_join_cond(db: gp.Database, zoo_1: gp.DataFrame): @@ -271,20 +288,17 @@ def test_dataframe_join_save(db: gp.Database, zoo_1: gp.DataFrame): ) t_join.save_as( "dataframe_join", - column_names=["zoo1_animal", "zoo1_id", "zoo2_animal", "zoo2_id"], + column_names=["animal", "zoo1_id", "zoo2_id"], temp=True, ) t_join_reload = gp.DataFrame.from_table("dataframe_join", db=db) assert sorted(next(iter(t_join_reload)).keys()) == sorted( [ - "zoo1_animal", + "animal", "zoo1_id", - "zoo2_animal", "zoo2_id", ] ) - for row in t_join_reload: - assert row["zoo1_animal"] == row["zoo2_animal"] def test_dataframe_join_ine(db: gp.Database): @@ -307,11 +321,19 @@ def test_dataframe_multiple_self_join(db: gp.Database, zoo_1: gp.DataFrame): ) ret = t_join.join( zoo_1, - cond=lambda s, o: s["zoo1_animal"] == o["animal"], + on=["animal"], + self_columns={"animal", "zoo1_id", "zoo2_id"}, + other_columns={"animal", "id"}, ) assert len(list(ret)) == 4 - for row in ret: - assert row["zoo2_animal"] == row["animal"] + assert sorted(next(iter(ret)).keys()) == sorted( + [ + "animal", + "id", + "zoo1_id", + "zoo2_id", + ] + ) # This test case is to guarantee that the CTEs are generated in the reversed diff --git a/tests/test_schema.py b/tests/test_schema.py index 863a5307..89f1c43d 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -47,7 +47,8 @@ def test_schema_self_join_on(db: gp.Database, t: gp.DataFrame): ret: gp.DataFrame = t.join( t, on=["id"], - other_columns={"id": "id_1"}, + self_columns={"id"}, + other_columns={"id"}, ) assert len(list(ret)) == 10