Improve Left/Right/Inner Join #223

Open · wants to merge 8 commits into main
dask_sql/physical/rel/logical/join.py (18 changes: 17 additions & 1 deletion)

@@ -116,7 +116,10 @@ def merge_single_partitions(lhs_partition, rhs_partition):
            # which is definitely not possible (java dependency, JVM start...)
            lhs_partition = lhs_partition.assign(common=1)
            rhs_partition = rhs_partition.assign(common=1)
-           merged_data = lhs_partition.merge(rhs_partition, on=["common"])
+           # Need to drop "common" here, otherwise metadata mismatches
+           merged_data = lhs_partition.merge(rhs_partition, on=["common"]).drop(
+               columns=["common"]
+           )

            return merged_data
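For context, merging on a constant `common` column is the usual pandas trick for emulating a cross join. Below is a minimal sketch with made-up frames (not part of the PR) showing the trick, and why the helper column has to be dropped so the partition's columns match the metadata dask expects:

import pandas as pd

lhs = pd.DataFrame({"a": [1, 2]})
rhs = pd.DataFrame({"b": [10, 20]})

# Merging on a constant helper column pairs every lhs row with every rhs row...
cross = (
    lhs.assign(common=1)
    .merge(rhs.assign(common=1), on=["common"])
    .drop(columns=["common"])  # ...then drop the helper so the schema stays clean
)
print(cross)  # 4 rows: (1, 10), (1, 20), (2, 10), (2, 20)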

@@ -179,7 +182,20 @@ def merge_single_partitions(lhs_partition, rhs_partition):
            )
            logger.debug(f"Additionally applying filter {filter_condition}")
            df = filter_or_scalar(df, filter_condition)
+       # make sure we recover any lost rows in case of left, right or outer joins
+       if join_type in ["left", "outer"]:
+           df = df.merge(
+               df_lhs_renamed, on=list(df_lhs_renamed.columns), how="right"
+           )
+       elif join_type in ["right", "outer"]:
+           df = df.merge(
+               df_rhs_renamed, on=list(df_rhs_renamed.columns), how="right"
+           )
        dc = DataContainer(df, cc)
+       # Caveat: columns of int may be cast to float if NaN is introduced
+       # for unmatched rows. Since we don't know which column would be cast
+       # without triggering compute(), we have to either leave it alone or
+       # forcibly cast all int to nullable int.

        dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
        return dc
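To make the recovery step concrete: an inner merge (plus any later filtering) silently drops rows without a partner, and re-merging the result against the full renamed input on all of its columns with how="right" brings them back. A minimal standalone sketch with toy frames (not from the PR), which also shows where the int-to-float caveat comes from:

import pandas as pd

left = pd.DataFrame({"id": [0, 1, 2], "x": ["a", "b", "c"]})
right = pd.DataFrame({"id": [1, 1, 2], "y": [10, 11, 12]})

# The inner join drops id == 0, which has no partner in `right`
inner = left.merge(right, on="id")

# Merging back on all of `left`'s columns with how="right" restores that row
recovered = inner.merge(left, on=list(left.columns), how="right")

# id == 0 reappears with y == NaN; note that "y" is now float64, because a
# plain int64 column cannot hold NaN -- exactly the caveat in the comment above
print(recovered.dtypes)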
tests/integration/fixtures.py (80 changes: 76 additions & 4 deletions)

@@ -86,6 +86,50 @@ def datetime_table():
    )


@pytest.fixture
def user_table_lk():
[Review thread]

Collaborator: I love these tests; they are super cool because they seem to come from a real use case, which is absolutely brilliant.

However, can we also have a very simple one with just 3-4 lines and two columns (e.g. the one I used in my comments)? That would make debugging much easier than skimming through multiple lines which (because the columns are so wide) even span a lot of space in the editor. I can also take care of this if you want!

Author: Yeah, I realize that. I simplified the new tests.

    # Link table identified by id and startdate
    # Used for query with both equality and inequality conditions
    out = pd.DataFrame(
        [[0, 5, 11, 111], [1, 2, pd.NA, 112], [1, 4, 13, 113], [3, 1, 14, 114],],
        columns=["id", "startdate", "lk_nullint", "lk_int"],
    )
    out["lk_nullint"] = out["lk_nullint"].astype("Int32")
    return out


@pytest.fixture
def user_table_lk2(user_table_lk):
    # Link table identified by startdate only
    # Used for query with inequality conditions
    out = pd.DataFrame(
        [[2, pd.NA, 112], [4, 13, 113],], columns=["startdate", "lk_nullint", "lk_int"],
    )
    out["lk_nullint"] = out["lk_nullint"].astype("Int32")
    return out


@pytest.fixture
def user_table_ts():
    # A table of time-series data identified by dates
    out = pd.DataFrame(
        [[1, 21], [3, pd.NA], [7, 23],], columns=["dates", "ts_nullint"],
    )
    out["ts_nullint"] = out["ts_nullint"].astype("Int32")
    return out


@pytest.fixture
def user_table_pn():
    # A panel table identified by id and dates
    out = pd.DataFrame(
        [[0, 1, pd.NA], [1, 5, 32], [2, 1, 33],],
        columns=["ids", "dates", "pn_nullint"],
    )
    out["pn_nullint"] = out["pn_nullint"].astype("Int32")
    return out
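Taken together, the link tables are meant to be joined onto the panel and time-series tables with conditions that mix equality and inequality, which is the pattern the new tests exercise. A sketch of the intended query shape (assuming a dask_sql Context `c` with the fixtures above registered, as in the tests below):

query = """
    select a.*, b.startdate, b.lk_nullint, b.lk_int
    from user_table_pn a left join user_table_lk b
    on a.ids = b.id and b.startdate <= a.dates
"""
# equality on the entity (ids = id), inequality on the date (startdate <= dates)
df = c.sql(query).compute()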


@pytest.fixture()
def c(
    df_simple,

@@ -97,6 +141,10 @@ def c(
    user_table_nan,
    string_table,
    datetime_table,
+   user_table_lk,
+   user_table_lk2,
+   user_table_ts,
+   user_table_pn,
):
    dfs = {
        "df_simple": df_simple,

@@ -108,6 +156,10 @@
        "user_table_nan": user_table_nan,
        "string_table": string_table,
        "datetime_table": datetime_table,
+       "user_table_lk": user_table_lk,
+       "user_table_lk2": user_table_lk2,
+       "user_table_ts": user_table_ts,
+       "user_table_pn": user_table_pn,
    }

    # Lazy import, otherwise the pytest framework has problems
@@ -134,7 +186,9 @@ def temporary_data_file():


@pytest.fixture()
-def assert_query_gives_same_result(engine):
+def assert_query_gives_same_result(
+    engine, user_table_lk, user_table_lk2, user_table_ts, user_table_pn,
+):
    np.random.seed(42)

    df1 = dd.from_pandas(

@@ -191,12 +245,22 @@ def assert_query_gives_same_result(engine):
    c.create_table("df1", df1)
    c.create_table("df2", df2)
    c.create_table("df3", df3)
+   c.create_table("user_table_ts", user_table_ts)
+   c.create_table("user_table_pn", user_table_pn)
+   c.create_table("user_table_lk", user_table_lk)
+   c.create_table("user_table_lk2", user_table_lk2)

    df1.compute().to_sql("df1", engine, index=False, if_exists="replace")
    df2.compute().to_sql("df2", engine, index=False, if_exists="replace")
    df3.compute().to_sql("df3", engine, index=False, if_exists="replace")
+   user_table_ts.to_sql("user_table_ts", engine, index=False, if_exists="replace")
+   user_table_pn.to_sql("user_table_pn", engine, index=False, if_exists="replace")
+   user_table_lk.to_sql("user_table_lk", engine, index=False, if_exists="replace")
+   user_table_lk2.to_sql("user_table_lk2", engine, index=False, if_exists="replace")

-   def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
+   def _assert_query_gives_same_result(
+       query, sort_columns=None, force_dtype=None, check_dtype=False, **kwargs,
+   ):
        sql_result = pd.read_sql_query(query, engine)
        dask_result = c.sql(query).compute()

@@ -211,7 +275,15 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
        sql_result = sql_result.reset_index(drop=True)
        dask_result = dask_result.reset_index(drop=True)

-       assert_frame_equal(sql_result, dask_result, check_dtype=False, **kwargs)
+       # Change dtypes
+       if force_dtype == "sql":
+           for col, dtype in sql_result.dtypes.iteritems():
+               dask_result[col] = dask_result[col].astype(dtype)
+       elif force_dtype == "dask":
+           for col, dtype in dask_result.dtypes.iteritems():
+               sql_result[col] = sql_result[col].astype(dtype)
+
+       assert_frame_equal(sql_result, dask_result, check_dtype=check_dtype, **kwargs)

    return _assert_query_gives_same_result
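The new force_dtype knob exists because the database backend and dask frequently disagree on result dtypes (for example plain int64 on the SQL side versus nullable Int32 on the dask side), and casting one side to the other's dtypes keeps the comparison meaningful. A hypothetical usage sketch (test name and query invented for illustration):

def test_left_join_matches_db(assert_query_gives_same_result):
    assert_query_gives_same_result(
        """
        select a.*, b.lk_nullint, b.lk_int
        from user_table_pn a left join user_table_lk b
        on a.ids = b.id and b.startdate <= a.dates
        """,
        sort_columns=["ids", "dates", "startdate"],
        force_dtype="dask",  # cast the SQL result to dask's dtypes before comparing
    )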

tests/integration/test_join.py (146 changes: 146 additions & 0 deletions)

@@ -184,3 +184,149 @@ def test_join_literal(c):
    df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []})

    assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False)


def test_join_lricomplex(c):
    # ---------- Panel data (equality and inequality conditions)

    # Correct answer
    dfcorrpn = pd.DataFrame(
        [
            [0, 1, pd.NA, pd.NA, pd.NA, pd.NA],
            [1, 5, 32, 2, pd.NA, 112],
            [1, 5, 32, 4, 13, 113],
            [2, 1, 33, pd.NA, pd.NA, pd.NA],
        ],
        columns=["ids", "dates", "pn_nullint", "startdate", "lk_nullint", "lk_int",],
    )
    change_types = {
        "pn_nullint": "Int32",
        "lk_nullint": "Int32",
        "startdate": "Int64",
        "lk_int": "Int64",
    }
    for k, v in change_types.items():
        dfcorrpn[k] = dfcorrpn[k].astype(v)

    # Left Join
    querypnl = """
        select a.*, b.startdate, b.lk_nullint, b.lk_int
        from user_table_pn a left join user_table_lk b
        on a.ids=b.id and b.startdate<=a.dates
        """
    dftestpnl = (
        c.sql(querypnl)
        .compute()
        .sort_values(["ids", "dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(dftestpnl, dfcorrpn, check_dtype=False)

    # Right Join
    querypnr = """
        select b.*, a.startdate, a.lk_nullint, a.lk_int
        from user_table_lk a right join user_table_pn b
        on b.ids=a.id and a.startdate<=b.dates
        """
    dftestpnr = (
        c.sql(querypnr)
        .compute()
        .sort_values(["ids", "dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(dftestpnr, dfcorrpn, check_dtype=False)

    # Inner Join
    querypni = """
        select a.*, b.startdate, b.lk_nullint, b.lk_int
        from user_table_pn a inner join user_table_lk b
        on a.ids=b.id and b.startdate<=a.dates
        """
    dftestpni = (
        c.sql(querypni)
        .compute()
        .sort_values(["ids", "dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(
        dftestpni,
        dfcorrpn.dropna(subset=["startdate"])
        .assign(
            startdate=lambda x: x["startdate"].astype("int64"),
            lk_int=lambda x: x["lk_int"].astype("int64"),
        )
        .reset_index(drop=True),
        check_dtype=False,
    )

    # ---------- Time-series data (inequality condition only)

    # Correct answer
    dfcorrts = pd.DataFrame(
        [
            [1, 21, pd.NA, pd.NA, pd.NA],
            [3, pd.NA, 2, pd.NA, 112],
            [7, 23, 2, pd.NA, 112],
            [7, 23, 4, 13, 113],
        ],
        columns=["dates", "ts_nullint", "startdate", "lk_nullint", "lk_int",],
    )
    change_types = {
        "ts_nullint": "Int32",
        "lk_nullint": "Int32",
        "startdate": "Int64",
        "lk_int": "Int64",
    }
    for k, v in change_types.items():
        dfcorrts[k] = dfcorrts[k].astype(v)

    # Left Join
    querytsl = """
        select a.*, b.startdate, b.lk_nullint, b.lk_int
        from user_table_ts a left join user_table_lk2 b
        on b.startdate<=a.dates
        """
    dftesttsl = (
        c.sql(querytsl)
        .compute()
        .sort_values(["dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(dftesttsl, dfcorrts, check_dtype=False)

    # Right Join
    querytsr = """
        select b.*, a.startdate, a.lk_nullint, a.lk_int
        from user_table_lk2 a right join user_table_ts b
        on a.startdate<=b.dates
        """
    dftesttsr = (
        c.sql(querytsr)
        .compute()
        .sort_values(["dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(dftesttsr, dfcorrts, check_dtype=False)

    # Inner Join
    querytsi = """
        select a.*, b.startdate, b.lk_nullint, b.lk_int
        from user_table_ts a inner join user_table_lk2 b
        on b.startdate<=a.dates
        """
    dftesttsi = (
        c.sql(querytsi)
        .compute()
        .sort_values(["dates", "startdate"])
        .reset_index(drop=True)
    )
    assert_frame_equal(
        dftesttsi,
        dfcorrts.dropna(subset=["startdate"])
        .assign(
            startdate=lambda x: x["startdate"].astype("int64"),
            lk_int=lambda x: x["lk_int"].astype("int64"),
        )
        .reset_index(drop=True),
        check_dtype=False,
    )
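A note on the dtype handling in these assertions: the expected frames use pandas' nullable Int32/Int64 dtypes because a plain NumPy integer column cannot hold a missing value, which is exactly the caveat flagged in join.py. A minimal sketch (values made up):

import pandas as pd

print(pd.Series([1, 2, None]).dtype)                 # float64: None forces an upcast
print(pd.Series([1, 2, None], dtype="Int64").dtype)  # Int64: nullable int keeps integers

This is also why the inner-join expectations cast startdate and lk_int back to plain "int64": once the unmatched (all-NA) rows are dropped, no missing values remain.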