From d54b2a69f15e85e4f76a3624bea7275c3ce5394b Mon Sep 17 00:00:00 2001 From: Linxiao Francis Cong Date: Mon, 23 Aug 2021 10:11:17 -0400 Subject: [PATCH 1/8] Improve Left/Right/Inner Join --- dask_sql/physical/rel/logical/join.py | 143 ++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 6 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index b0c079f3c..a258448a3 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -5,6 +5,9 @@ from typing import List, Tuple import dask.dataframe as dd + +# Need pd.NA +import pandas as pd from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph @@ -92,12 +95,32 @@ def convert( # 4. dask can only merge on the same column names. # We therefore create new columns on purpose, which have a distinct name. assert len(lhs_on) == len(rhs_on) + # Add two columns (1,2,...) to keep track of observations in left and + # right tables. They must be at the end of the columns since + # _join_on_columns needs the relative order of columns (lhs_on and rhs_on) + # Only dask-supported functions are used (assign and cumsum) so that a + # compute() is not triggered. + df_lhs_renamed = df_lhs_renamed.assign(left_idx=1) + df_lhs_renamed = df_lhs_renamed.assign( + left_idx=df_lhs_renamed["left_idx"].cumsum() + ) + df_rhs_renamed = df_rhs_renamed.assign(right_idx=1) + df_rhs_renamed = df_rhs_renamed.assign( + right_idx=df_rhs_renamed["right_idx"].cumsum() + ) + if lhs_on: # 5. Now we can finally merge on these columns # The resulting dataframe will contain all (renamed) columns from the lhs and rhs # plus the added columns + # Need the indicator for left/right join df = self._join_on_columns( - df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, join_type, + df_lhs_renamed, + df_rhs_renamed, + lhs_on, + rhs_on, + join_type, + indicator=True, ) else: # 5. We are in the complex join case @@ -148,10 +171,28 @@ def merge_single_partitions(lhs_partition, rhs_partition): ResourceWarning, ) + # Add _merge to be consistent with the case lhs_on=True + df["_merge"] = "both" + df["_merge"] = df["_merge"].astype("category") + # Put newly added columns to the end + df = df[ + df.columns.drop("left_idx").insert( + df.columns.get_loc("right_idx") - 1, "left_idx" + ) + ] + + # Completely reset index to uniquely identify each row since there + # could be duplicates. (Yeah. It may be better to inform users that + # index will break. After all, it is expected to be broken since the + # number of rows changes. + df = df.assign(uniqid=1) + df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid") + # 6. 
So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) - correct_column_order = list(df_lhs_renamed.columns) + list( - df_rhs_renamed.columns + # Need to exclude temporary columns left_idx and right_idx + correct_column_order = list(df_lhs_renamed.columns.drop("left_idx")) + list( + df_rhs_renamed.columns.drop("right_idx") ) cc = ColumnContainer(df.columns).limit_to(correct_column_order) @@ -177,8 +218,91 @@ def merge_single_partitions(lhs_partition, rhs_partition): for rex in filter_condition ], ) - logger.debug(f"Additionally applying filter {filter_condition}") - df = filter_or_scalar(df, filter_condition) + # Three cases to deal with inequality conditions (left join as an example): + # Case 1 [eq_unmatched] (Not matched by equality): + # Left-only from equality join (_merge=='left_only') + # => Keep all + # Case 2 [ineq_unmatched] (Not matched by inequality): + # For unique left_idx, there are no True in filter_condition + # => Set values from right/left table to missing (NaN or NaT) + # => Keep 1 copy and drop duplicates over left_idx (there could + # be duplicates now due to equality match). + # Case 3 (Matched by inequality): + # For unique left_idx, there are 1 or more True in filter_condition + # => Keep obs with True in filter_condition + # This has to be added to df since partition will break the groupby + df["filter_condition"] = filter_condition + if join_type in ["left", "right"]: + # ----- Case 1 (Not matched by equality) + if join_type == "left": + # Flag obs unmatched in equality join + df["eq_unmatched"] = df["_merge"] == "left_only" + idx_varname = "left_idx" + other_varpre = "rhs_" + else: + # Flag obs unmatched in equality join + df["eq_unmatched"] = df["_merge"] == "right_only" + idx_varname = "right_idx" + other_varpre = "lhs_" + + # ----- Case 2 (Not matched by inequality) + + # Set NA (pd.NA) + # Flag obs not matched by inequality + df = df.merge( + (df.groupby(idx_varname)["filter_condition"].agg("sum") < 1) + .rename("ineq_unmatched") + .to_frame(), + left_on=idx_varname, + right_index=True, + how="left", + ) + # Assign pd.NA + for v in df.columns[df.columns.str.startswith(other_varpre)]: + df[v] = df[v].mask( + df["ineq_unmatched"] & (~df["eq_unmatched"]), pd.NA + ) + + # Drop duplicates + # Flag the first obs for each unique left_idx + # (or right_idx for right join) in order to remove duplicates + df = df.merge( + df[[idx_varname]] + .drop_duplicates() + .assign(first_elem=True) + .drop(columns=[idx_varname]), + left_index=True, + right_index=True, + how="left", + ) + df["first_elem"] = df["first_elem"].fillna(False) + + # ----- The full condition to keep observations + filter_condition_all = ( + df["filter_condition"] + | df["eq_unmatched"] + | (df["ineq_unmatched"] & df["first_elem"]) + ) + # Drop added temporary columns + df = df.drop( + columns=[ + "left_idx", + "right_idx", + "_merge", + "filter_condition", + "eq_unmatched", + "ineq_unmatched", + "first_elem", + ] + ) + elif join_type == "inner": + filter_condition_all = filter_condition + # TODO: Full Join + + logger.debug(f"Additionally applying filter {filter_condition_all}") + df = filter_or_scalar(df, filter_condition_all) + # Reset index (maybe notify users that dask-sql may break index) + df = df.reset_index(drop=True) dc = DataContainer(df, cc) dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) @@ -191,6 +315,7 @@ def _join_on_columns( lhs_on: List[str], rhs_on: List[str], join_type: str, + indicator: bool = False, ) -> dd.DataFrame: 
lhs_columns_to_add = { f"common_{i}": df_lhs_renamed.iloc[:, index] @@ -222,7 +347,13 @@ def _join_on_columns( df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add) added_columns = list(lhs_columns_to_add.keys()) - df = dd.merge(df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type) + df = dd.merge( + df_lhs_with_tmp, + df_rhs_with_tmp, + on=added_columns, + how=join_type, + indicator=indicator, + ) return df From 43edc685c339a577c1460056621f520505ee6166 Mon Sep 17 00:00:00 2001 From: Linxiao Francis Cong Date: Mon, 23 Aug 2021 10:11:54 -0400 Subject: [PATCH 2/8] Add new datasets in tests --- tests/integration/fixtures.py | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 4566d3690..78ed0314f 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -86,6 +86,88 @@ def datetime_table(): ) +@pytest.fixture +def user_table_lk(): + # Link table identified by id and date range (startdate and enddate) + # Used for query with both equality and inequality conditions + out = pd.DataFrame( + [ + [0, 0, 2, pd.NA, 110, "a1", 1.1, pd.Timestamp("2001-01-01")], + [0, 4, 6, pd.NA, 111, "a2", 1.2, pd.Timestamp("2001-02-01")], + [1, 2, 5, pd.NA, 112, "a3", np.nan, pd.Timestamp("2001-03-01")], + [1, 4, 6, 13, 113, "a4", np.nan, pd.Timestamp("2001-04-01")], + [3, 1, 2, 14, 114, "a5", np.nan, pd.NaT], + [3, 2, 3, 15, 115, "a6", 1.6, pd.NaT], + ], + columns=[ + "id", + "startdate", + "enddate", + "lk_nullint", + "lk_int", + "lk_str", + "lk_float", + "lk_date", + ], + ) + out["lk_nullint"] = out["lk_nullint"].astype("Int32") + out["lk_str"] = out["lk_str"].astype("string") + return out + + +@pytest.fixture +def user_table_lk2(user_table_lk): + # Link table identified by only date range (startdate and enddate) + # Used for query with inequality conditions + return user_table_lk.set_index("id").loc[1].reset_index(drop=True) + + +@pytest.fixture +def user_table_ts(): + # A table of time-series data identified by dates + out = pd.DataFrame( + [ + [3, pd.NA, 221, "b1", 2.1, pd.Timestamp("2002-01-01")], + [4, 22, 222, "b2", np.nan, pd.Timestamp("2002-02-01")], + [7, 23, 223, "b3", 2.3, pd.NaT], + ], + columns=["dates", "ts_nullint", "ts_int", "ts_str", "ts_float", "ts_date"], + ) + out["ts_nullint"] = out["ts_nullint"].astype("Int32") + out["ts_str"] = out["ts_str"].astype("string") + return out + + +@pytest.fixture +def user_table_pn(): + # A panel table identified by id and dates + out = pd.DataFrame( + [ + [0, 1, pd.NA, 331, "c1", 3.1, pd.Timestamp("2003-01-01")], + [0, 2, pd.NA, 332, "c2", 3.2, pd.Timestamp("2003-02-01")], + [0, 3, pd.NA, 333, "c3", 3.3, pd.Timestamp("2003-03-01")], + [1, 3, pd.NA, 334, "c4", np.nan, pd.Timestamp("2003-04-01")], + [1, 4, 35, 335, "c5", np.nan, pd.Timestamp("2003-05-01")], + [2, 1, 36, 336, "c6", np.nan, pd.Timestamp("2003-06-01")], + [2, 3, 37, 337, "c7", np.nan, pd.NaT], + [3, 2, 38, 338, "c8", 3.8, pd.NaT], + [3, 2, 39, 339, "c9", 3.9, pd.NaT], + ], + columns=[ + "ids", + "dates", + "pn_nullint", + "pn_int", + "pn_str", + "pn_float", + "pn_date", + ], + ) + out["pn_nullint"] = out["pn_nullint"].astype("Int32") + out["pn_str"] = out["pn_str"].astype("string") + return out + + @pytest.fixture() def c( df_simple, @@ -97,6 +179,10 @@ def c( user_table_nan, string_table, datetime_table, + user_table_lk, + user_table_lk2, + user_table_ts, + user_table_pn, ): dfs = { "df_simple": df_simple, @@ -108,6 +194,10 @@ def c( "user_table_nan": 
user_table_nan, "string_table": string_table, "datetime_table": datetime_table, + "user_table_lk": user_table_lk, + "user_table_lk2": user_table_lk2, + "user_table_ts": user_table_ts, + "user_table_pn": user_table_pn, } # Lazy import, otherwise the pytest framework has problems From c199fe76cd61751bfa5cfb788c39711b7da37613 Mon Sep 17 00:00:00 2001 From: Linxiao Francis Cong Date: Mon, 23 Aug 2021 10:12:19 -0400 Subject: [PATCH 3/8] Fix issue for new datasets --- tests/integration/test_show.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/test_show.py b/tests/integration/test_show.py index 2165699ca..76749ef53 100644 --- a/tests/integration/test_show.py +++ b/tests/integration/test_show.py @@ -35,6 +35,10 @@ def test_tables(c): "user_table_nan", "string_table", "datetime_table", + "user_table_lk", + "user_table_lk2", + "user_table_ts", + "user_table_pn", ] } ) From 29a201dd9fb8919a2ecce039d1f8399f8e8bd380 Mon Sep 17 00:00:00 2001 From: Linxiao Francis Cong Date: Mon, 23 Aug 2021 10:12:41 -0400 Subject: [PATCH 4/8] Add new test for complex join --- tests/integration/test_join.py | 355 +++++++++++++++++++++++++++++++++ 1 file changed, 355 insertions(+) diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 6437cde0f..3507ac590 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -184,3 +184,358 @@ def test_join_literal(c): df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []}) assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False) + + +def test_join_lricomplex(c): + # ---------- Panel data (equality and inequality conditions) + + # Correct answer + dfcorrpn = pd.DataFrame( + [ + [ + 0, + 1, + pd.NA, + 331, + "c1", + 3.1, + pd.Timestamp("2003-01-01"), + 0, + 2, + pd.NA, + 110, + "a1", + 1.1, + pd.Timestamp("2001-01-01"), + ], + [ + 0, + 2, + pd.NA, + 332, + "c2", + 3.2, + pd.Timestamp("2003-02-01"), + 0, + 2, + pd.NA, + 110, + "a1", + 1.1, + pd.Timestamp("2001-01-01"), + ], + [ + 0, + 3, + pd.NA, + 333, + "c3", + 3.3, + pd.Timestamp("2003-03-01"), + pd.NA, + pd.NA, + pd.NA, + pd.NA, + np.nan, + np.nan, + pd.NaT, + ], + [ + 1, + 3, + pd.NA, + 334, + "c4", + np.nan, + pd.Timestamp("2003-04-01"), + 2, + 5, + pd.NA, + 112, + "a3", + np.nan, + pd.Timestamp("2001-03-01"), + ], + [ + 1, + 4, + 35, + 335, + "c5", + np.nan, + pd.Timestamp("2003-05-01"), + 2, + 5, + pd.NA, + 112, + "a3", + np.nan, + pd.Timestamp("2001-03-01"), + ], + [ + 1, + 4, + 35, + 335, + "c5", + np.nan, + pd.Timestamp("2003-05-01"), + 4, + 6, + 13, + 113, + "a4", + np.nan, + pd.Timestamp("2001-04-01"), + ], + [ + 2, + 1, + 36, + 336, + "c6", + np.nan, + pd.Timestamp("2003-06-01"), + pd.NA, + pd.NA, + pd.NA, + pd.NA, + np.nan, + np.nan, + pd.NaT, + ], + [ + 2, + 3, + 37, + 337, + "c7", + np.nan, + pd.NaT, + pd.NA, + pd.NA, + pd.NA, + pd.NA, + np.nan, + np.nan, + pd.NaT, + ], + [3, 2, 38, 338, "c8", 3.8, pd.NaT, 1, 2, 14, 114, "a5", np.nan, pd.NaT], + [3, 2, 39, 339, "c9", 3.9, pd.NaT, 1, 2, 14, 114, "a5", np.nan, pd.NaT], + [3, 2, 38, 338, "c8", 3.8, pd.NaT, 2, 3, 15, 115, "a6", 1.6, pd.NaT], + [3, 2, 39, 339, "c9", 3.9, pd.NaT, 2, 3, 15, 115, "a6", 1.6, pd.NaT], + ], + columns=[ + "ids", + "dates", + "pn_nullint", + "pn_int", + "pn_str", + "pn_float", + "pn_date", + "startdate", + "enddate", + "lk_nullint", + "lk_int", + "lk_str", + "lk_float", + "lk_date", + ], + ) + change_types = { + "pn_nullint": "Int32", + "lk_nullint": "Int32", + "startdate": "Int64", + "enddate": "Int64", + "lk_int": 
"Int64", + "pn_str": "string", + "lk_str": "string", + } + for k, v in change_types.items(): + dfcorrpn[k] = dfcorrpn[k].astype(v) + + # Left Join + querypnl = """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_pn a left join user_table_lk b + on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + """ + dftestpnl = ( + c.sql(querypnl).compute().sort_values(["ids", "dates", "startdate", "enddate"]) + ) + assert_frame_equal( + dftestpnl.reset_index(drop=True), dfcorrpn.reset_index(drop=True) + ) + + # Right Join + querypnr = """ + select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, + a.lk_float, a.lk_date + from user_table_lk a right join user_table_pn b + on b.ids=a.id and a.startdate<=b.dates and b.dates<=a.enddate + """ + dftestpnr = ( + c.sql(querypnr).compute().sort_values(["ids", "dates", "startdate", "enddate"]) + ) + assert_frame_equal( + dftestpnr.reset_index(drop=True), dfcorrpn.reset_index(drop=True) + ) + + # Inner Join + querypni = """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_pn a inner join user_table_lk b + on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + """ + dftestpni = ( + c.sql(querypni).compute().sort_values(["ids", "dates", "startdate", "enddate"]) + ) + assert_frame_equal( + dftestpni.reset_index(drop=True), + dfcorrpn.dropna(subset=["startdate"]) + .assign( + startdate=lambda x: x["startdate"].astype("int64"), + enddate=lambda x: x["enddate"].astype("int64"), + lk_int=lambda x: x["lk_int"].astype("int64"), + ) + .reset_index(drop=True), + ) + + # ---------- Time-series data (inequality condition only) + + # Correct answer + dfcorrts = pd.DataFrame( + [ + [ + 3, + pd.NA, + 221, + "b1", + 2.1, + pd.Timestamp("2002-01-01"), + 2, + 5, + pd.NA, + 112, + "a3", + np.nan, + pd.Timestamp("2001-03-01"), + ], + [ + 4, + 22, + 222, + "b2", + np.nan, + pd.Timestamp("2002-02-01"), + 2, + 5, + pd.NA, + 112, + "a3", + np.nan, + pd.Timestamp("2001-03-01"), + ], + [ + 4, + 22, + 222, + "b2", + np.nan, + pd.Timestamp("2002-02-01"), + 4, + 6, + 13, + 113, + "a4", + np.nan, + pd.Timestamp("2001-04-01"), + ], + [ + 7, + 23, + 223, + "b3", + 2.3, + pd.NaT, + pd.NA, + pd.NA, + pd.NA, + pd.NA, + np.nan, + np.nan, + pd.NaT, + ], + ], + columns=[ + "dates", + "ts_nullint", + "ts_int", + "ts_str", + "ts_float", + "ts_date", + "startdate", + "enddate", + "lk_nullint", + "lk_int", + "lk_str", + "lk_float", + "lk_date", + ], + ) + change_types = { + "ts_nullint": "Int32", + "lk_nullint": "Int32", + "startdate": "Int64", + "enddate": "Int64", + "lk_int": "Int64", + "lk_str": "string", + "ts_str": "string", + } + for k, v in change_types.items(): + dfcorrts[k] = dfcorrts[k].astype(v) + + # Left Join + querytsl = """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_ts a left join user_table_lk2 b + on b.startdate<=a.dates and a.dates<=b.enddate + """ + dftesttsl = c.sql(querytsl).compute().sort_values(["dates", "startdate", "enddate"]) + assert_frame_equal( + dftesttsl.reset_index(drop=True), dfcorrts.reset_index(drop=True) + ) + + # Right Join + querytsr = """ + select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, + a.lk_float, a.lk_date + from user_table_lk2 a right join user_table_ts b + on a.startdate<=b.dates and b.dates<=a.enddate + """ + dftesttsr = c.sql(querytsr).compute().sort_values(["dates", "startdate", "enddate"]) + 
assert_frame_equal( + dftesttsr.reset_index(drop=True), dfcorrts.reset_index(drop=True) + ) + + # Inner Join + querytsi = """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_ts a inner join user_table_lk2 b + on b.startdate<=a.dates and a.dates<=b.enddate + """ + dftesttsi = c.sql(querytsi).compute().sort_values(["dates", "startdate", "enddate"]) + assert_frame_equal( + dftesttsi.reset_index(drop=True), + dfcorrts.dropna(subset=["startdate"]) + .assign( + startdate=lambda x: x["startdate"].astype("int64"), + enddate=lambda x: x["enddate"].astype("int64"), + lk_int=lambda x: x["lk_int"].astype("int64"), + ) + .reset_index(drop=True), + ) From a7c1e30d0667acb51494a8cdac9195ad852c7609 Mon Sep 17 00:00:00 2001 From: Francis Cong Date: Wed, 25 Aug 2021 08:39:14 -0400 Subject: [PATCH 5/8] Add new test data in fixture `assert_query_gives_same_result` * Add new data * Allow for type conversion * Allow for specification of whether or not to check dtypes. --- tests/integration/fixtures.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 78ed0314f..e80e92e92 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -224,7 +224,9 @@ def temporary_data_file(): @pytest.fixture() -def assert_query_gives_same_result(engine): +def assert_query_gives_same_result( + engine, user_table_lk, user_table_lk2, user_table_ts, user_table_pn, +): np.random.seed(42) df1 = dd.from_pandas( @@ -281,12 +283,22 @@ def assert_query_gives_same_result(engine): c.create_table("df1", df1) c.create_table("df2", df2) c.create_table("df3", df3) + c.create_table("user_table_ts", user_table_ts) + c.create_table("user_table_pn", user_table_pn) + c.create_table("user_table_lk", user_table_lk) + c.create_table("user_table_lk2", user_table_lk2) df1.compute().to_sql("df1", engine, index=False, if_exists="replace") df2.compute().to_sql("df2", engine, index=False, if_exists="replace") df3.compute().to_sql("df3", engine, index=False, if_exists="replace") - - def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): + user_table_ts.to_sql("user_table_ts", engine, index=False, if_exists="replace") + user_table_pn.to_sql("user_table_pn", engine, index=False, if_exists="replace") + user_table_lk.to_sql("user_table_lk", engine, index=False, if_exists="replace") + user_table_lk2.to_sql("user_table_lk2", engine, index=False, if_exists="replace") + + def _assert_query_gives_same_result( + query, sort_columns=None, force_dtype=None, check_dtype=False, **kwargs, + ): sql_result = pd.read_sql_query(query, engine) dask_result = c.sql(query).compute() @@ -301,7 +313,15 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): sql_result = sql_result.reset_index(drop=True) dask_result = dask_result.reset_index(drop=True) - assert_frame_equal(sql_result, dask_result, check_dtype=False, **kwargs) + # Change dtypes + if force_dtype == "sql": + for col, dtype in sql_result.dtypes.iteritems(): + dask_result[col] = dask_result[col].astype(dtype) + elif force_dtype == "dask": + for col, dtype in dask_result.dtypes.iteritems(): + sql_result[col] = sql_result[col].astype(dtype) + + assert_frame_equal(sql_result, dask_result, check_dtype=check_dtype, **kwargs) return _assert_query_gives_same_result From 0fc00db7f784a8f75a2e9612542dc477b735e39d Mon Sep 17 00:00:00 2001 From: Francis Cong Date: Wed, 25 Aug 2021 08:42:15 -0400 
Subject: [PATCH 6/8] Add test in Postgres * Add new tests in `test_join.py` in `test_postgres.py` as well. * Expose port 5432 for postgres container. * Change address to "localhost" --- tests/integration/test_postgres.py | 88 ++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/integration/test_postgres.py b/tests/integration/test_postgres.py index f1614d5ad..fc5f0924d 100644 --- a/tests/integration/test_postgres.py +++ b/tests/integration/test_postgres.py @@ -16,6 +16,7 @@ def engine(): remove=True, network="dask-sql", environment={"POSTGRES_HOST_AUTH_METHOD": "trust"}, + ports={"5432/tcp": "5432"}, ) try: @@ -32,6 +33,7 @@ def engine(): # get the address and create the connection postgres.reload() address = postgres.attrs["NetworkSettings"]["Networks"]["dask-sql"]["IPAddress"] + address = "localhost" port = 5432 engine = sqlalchemy.create_engine( @@ -126,6 +128,92 @@ def test_join(assert_query_gives_same_result): ) +def test_join_lricomplex( + assert_query_gives_same_result, + engine, + user_table_ts, + user_table_pn, + user_table_lk, + user_table_lk2, + c, +): + # ---------- Panel data + # Left Join + assert_query_gives_same_result( + """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_pn a left join user_table_lk b + on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + """, + ["ids", "dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + # Right Join + assert_query_gives_same_result( + """ + select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, + a.lk_float, a.lk_date + from user_table_lk a right join user_table_pn b + on b.ids=a.id and a.startdate<=b.dates and b.dates<=a.enddate + """, + ["ids", "dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + # Inner Join + assert_query_gives_same_result( + """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_pn a inner join user_table_lk b + on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + """, + ["ids", "dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + + # ---------- Time-series data + # Left Join + assert_query_gives_same_result( + """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_ts a left join user_table_lk2 b + on b.startdate<=a.dates and a.dates<=b.enddate + """, + ["dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + # Right Join + assert_query_gives_same_result( + """ + select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, + a.lk_float, a.lk_date + from user_table_lk2 a right join user_table_ts b + on a.startdate<=b.dates and b.dates<=a.enddate + """, + ["dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + # Inner Join + assert_query_gives_same_result( + """ + select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, + b.lk_float, b.lk_date + from user_table_ts a inner join user_table_lk2 b + on b.startdate<=a.dates and a.dates<=b.enddate + """, + ["dates", "startdate", "enddate"], + force_dtype="dask", + check_dtype=True, + ) + + def test_sort(assert_query_gives_same_result): assert_query_gives_same_result( """ From e32023c2d35f595769e8c064682b7322ed278cd1 Mon Sep 17 00:00:00 2001 From: Francis Cong Date: Mon, 30 Aug 2021 10:09:03 -0400 Subject: [PATCH 7/8] Roll back to previous join.py and add merge 
* Roll back to previous `join.py` but add lines to merge unmatched columns. * Fix a bug in `merge_single_partitions` where the returned dataframe has an extra column `"common"` that triggers metadata mismatch in the added lines of merging. --- dask_sql/physical/rel/logical/join.py | 153 ++++---------------------- 1 file changed, 19 insertions(+), 134 deletions(-) diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py index a258448a3..57037b782 100644 --- a/dask_sql/physical/rel/logical/join.py +++ b/dask_sql/physical/rel/logical/join.py @@ -5,9 +5,6 @@ from typing import List, Tuple import dask.dataframe as dd - -# Need pd.NA -import pandas as pd from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph @@ -95,32 +92,12 @@ def convert( # 4. dask can only merge on the same column names. # We therefore create new columns on purpose, which have a distinct name. assert len(lhs_on) == len(rhs_on) - # Add two columns (1,2,...) to keep track of observations in left and - # right tables. They must be at the end of the columns since - # _join_on_columns needs the relative order of columns (lhs_on and rhs_on) - # Only dask-supported functions are used (assign and cumsum) so that a - # compute() is not triggered. - df_lhs_renamed = df_lhs_renamed.assign(left_idx=1) - df_lhs_renamed = df_lhs_renamed.assign( - left_idx=df_lhs_renamed["left_idx"].cumsum() - ) - df_rhs_renamed = df_rhs_renamed.assign(right_idx=1) - df_rhs_renamed = df_rhs_renamed.assign( - right_idx=df_rhs_renamed["right_idx"].cumsum() - ) - if lhs_on: # 5. Now we can finally merge on these columns # The resulting dataframe will contain all (renamed) columns from the lhs and rhs # plus the added columns - # Need the indicator for left/right join df = self._join_on_columns( - df_lhs_renamed, - df_rhs_renamed, - lhs_on, - rhs_on, - join_type, - indicator=True, + df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, join_type, ) else: # 5. We are in the complex join case @@ -139,7 +116,10 @@ def merge_single_partitions(lhs_partition, rhs_partition): # which is definitely not possible (java dependency, JVM start...) lhs_partition = lhs_partition.assign(common=1) rhs_partition = rhs_partition.assign(common=1) - merged_data = lhs_partition.merge(rhs_partition, on=["common"]) + # Need to drop "common" here, otherwise metadata mismatches + merged_data = lhs_partition.merge(rhs_partition, on=["common"]).drop( + columns=["common"] + ) return merged_data @@ -171,28 +151,10 @@ def merge_single_partitions(lhs_partition, rhs_partition): ResourceWarning, ) - # Add _merge to be consistent with the case lhs_on=True - df["_merge"] = "both" - df["_merge"] = df["_merge"].astype("category") - # Put newly added columns to the end - df = df[ - df.columns.drop("left_idx").insert( - df.columns.get_loc("right_idx") - 1, "left_idx" - ) - ] - - # Completely reset index to uniquely identify each row since there - # could be duplicates. (Yeah. It may be better to inform users that - # index will break. After all, it is expected to be broken since the - # number of rows changes. - df = df.assign(uniqid=1) - df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid") - # 6. 
So the next step is to make sure # we have the correct column order (and to remove the temporary join columns) - # Need to exclude temporary columns left_idx and right_idx - correct_column_order = list(df_lhs_renamed.columns.drop("left_idx")) + list( - df_rhs_renamed.columns.drop("right_idx") + correct_column_order = list(df_lhs_renamed.columns) + list( + df_rhs_renamed.columns ) cc = ColumnContainer(df.columns).limit_to(correct_column_order) @@ -218,92 +180,22 @@ def merge_single_partitions(lhs_partition, rhs_partition): for rex in filter_condition ], ) - # Three cases to deal with inequality conditions (left join as an example): - # Case 1 [eq_unmatched] (Not matched by equality): - # Left-only from equality join (_merge=='left_only') - # => Keep all - # Case 2 [ineq_unmatched] (Not matched by inequality): - # For unique left_idx, there are no True in filter_condition - # => Set values from right/left table to missing (NaN or NaT) - # => Keep 1 copy and drop duplicates over left_idx (there could - # be duplicates now due to equality match). - # Case 3 (Matched by inequality): - # For unique left_idx, there are 1 or more True in filter_condition - # => Keep obs with True in filter_condition - # This has to be added to df since partition will break the groupby - df["filter_condition"] = filter_condition - if join_type in ["left", "right"]: - # ----- Case 1 (Not matched by equality) - if join_type == "left": - # Flag obs unmatched in equality join - df["eq_unmatched"] = df["_merge"] == "left_only" - idx_varname = "left_idx" - other_varpre = "rhs_" - else: - # Flag obs unmatched in equality join - df["eq_unmatched"] = df["_merge"] == "right_only" - idx_varname = "right_idx" - other_varpre = "lhs_" - - # ----- Case 2 (Not matched by inequality) - - # Set NA (pd.NA) - # Flag obs not matched by inequality + logger.debug(f"Additionally applying filter {filter_condition}") + df = filter_or_scalar(df, filter_condition) + # make sure we recover any lost rows in case of left, right or outer joins + if join_type in ["left", "outer"]: df = df.merge( - (df.groupby(idx_varname)["filter_condition"].agg("sum") < 1) - .rename("ineq_unmatched") - .to_frame(), - left_on=idx_varname, - right_index=True, - how="left", + df_lhs_renamed, on=list(df_lhs_renamed.columns), how="right" ) - # Assign pd.NA - for v in df.columns[df.columns.str.startswith(other_varpre)]: - df[v] = df[v].mask( - df["ineq_unmatched"] & (~df["eq_unmatched"]), pd.NA - ) - - # Drop duplicates - # Flag the first obs for each unique left_idx - # (or right_idx for right join) in order to remove duplicates + elif join_type in ["right", "outer"]: df = df.merge( - df[[idx_varname]] - .drop_duplicates() - .assign(first_elem=True) - .drop(columns=[idx_varname]), - left_index=True, - right_index=True, - how="left", - ) - df["first_elem"] = df["first_elem"].fillna(False) - - # ----- The full condition to keep observations - filter_condition_all = ( - df["filter_condition"] - | df["eq_unmatched"] - | (df["ineq_unmatched"] & df["first_elem"]) - ) - # Drop added temporary columns - df = df.drop( - columns=[ - "left_idx", - "right_idx", - "_merge", - "filter_condition", - "eq_unmatched", - "ineq_unmatched", - "first_elem", - ] + df_rhs_renamed, on=list(df_rhs_renamed.columns), how="right" ) - elif join_type == "inner": - filter_condition_all = filter_condition - # TODO: Full Join - - logger.debug(f"Additionally applying filter {filter_condition_all}") - df = filter_or_scalar(df, filter_condition_all) - # Reset index (maybe notify users that dask-sql may 
break index) - df = df.reset_index(drop=True) dc = DataContainer(df, cc) + # Caveat: columns of int may be casted to float if NaN is introduced + # for unmatched rows. Since we don't know which column would be casted + # without triggering compute(), we have to either leave it alone, or + # forcibly cast all int to nullable int. dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) return dc @@ -315,7 +207,6 @@ def _join_on_columns( lhs_on: List[str], rhs_on: List[str], join_type: str, - indicator: bool = False, ) -> dd.DataFrame: lhs_columns_to_add = { f"common_{i}": df_lhs_renamed.iloc[:, index] @@ -347,13 +238,7 @@ def _join_on_columns( df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add) added_columns = list(lhs_columns_to_add.keys()) - df = dd.merge( - df_lhs_with_tmp, - df_rhs_with_tmp, - on=added_columns, - how=join_type, - indicator=indicator, - ) + df = dd.merge(df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type) return df From 39980fd40f49ddf3c1910c8e36e8a88bd78b82de Mon Sep 17 00:00:00 2001 From: Francis Cong Date: Mon, 30 Aug 2021 10:11:02 -0400 Subject: [PATCH 8/8] Simplify new tests added * Since int columns will be converted to float if there are unmatched rows, I use `check_dtype=False` but still use nullable int in the assumed correct table. --- tests/integration/fixtures.py | 62 ++---- tests/integration/test_join.py | 327 ++++++----------------------- tests/integration/test_postgres.py | 48 ++--- 3 files changed, 93 insertions(+), 344 deletions(-) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index e80e92e92..fc1aefd66 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -88,53 +88,34 @@ def datetime_table(): @pytest.fixture def user_table_lk(): - # Link table identified by id and date range (startdate and enddate) + # Link table identified by id and startdate # Used for query with both equality and inequality conditions out = pd.DataFrame( - [ - [0, 0, 2, pd.NA, 110, "a1", 1.1, pd.Timestamp("2001-01-01")], - [0, 4, 6, pd.NA, 111, "a2", 1.2, pd.Timestamp("2001-02-01")], - [1, 2, 5, pd.NA, 112, "a3", np.nan, pd.Timestamp("2001-03-01")], - [1, 4, 6, 13, 113, "a4", np.nan, pd.Timestamp("2001-04-01")], - [3, 1, 2, 14, 114, "a5", np.nan, pd.NaT], - [3, 2, 3, 15, 115, "a6", 1.6, pd.NaT], - ], - columns=[ - "id", - "startdate", - "enddate", - "lk_nullint", - "lk_int", - "lk_str", - "lk_float", - "lk_date", - ], + [[0, 5, 11, 111], [1, 2, pd.NA, 112], [1, 4, 13, 113], [3, 1, 14, 114],], + columns=["id", "startdate", "lk_nullint", "lk_int"], ) out["lk_nullint"] = out["lk_nullint"].astype("Int32") - out["lk_str"] = out["lk_str"].astype("string") return out @pytest.fixture def user_table_lk2(user_table_lk): - # Link table identified by only date range (startdate and enddate) + # Link table identified by startdate only # Used for query with inequality conditions - return user_table_lk.set_index("id").loc[1].reset_index(drop=True) + out = pd.DataFrame( + [[2, pd.NA, 112], [4, 13, 113],], columns=["startdate", "lk_nullint", "lk_int"], + ) + out["lk_nullint"] = out["lk_nullint"].astype("Int32") + return out @pytest.fixture def user_table_ts(): # A table of time-series data identified by dates out = pd.DataFrame( - [ - [3, pd.NA, 221, "b1", 2.1, pd.Timestamp("2002-01-01")], - [4, 22, 222, "b2", np.nan, pd.Timestamp("2002-02-01")], - [7, 23, 223, "b3", 2.3, pd.NaT], - ], - columns=["dates", "ts_nullint", "ts_int", "ts_str", "ts_float", "ts_date"], + [[1, 21], [3, pd.NA], [7, 23],], columns=["dates", 
"ts_nullint"], ) out["ts_nullint"] = out["ts_nullint"].astype("Int32") - out["ts_str"] = out["ts_str"].astype("string") return out @@ -142,29 +123,10 @@ def user_table_ts(): def user_table_pn(): # A panel table identified by id and dates out = pd.DataFrame( - [ - [0, 1, pd.NA, 331, "c1", 3.1, pd.Timestamp("2003-01-01")], - [0, 2, pd.NA, 332, "c2", 3.2, pd.Timestamp("2003-02-01")], - [0, 3, pd.NA, 333, "c3", 3.3, pd.Timestamp("2003-03-01")], - [1, 3, pd.NA, 334, "c4", np.nan, pd.Timestamp("2003-04-01")], - [1, 4, 35, 335, "c5", np.nan, pd.Timestamp("2003-05-01")], - [2, 1, 36, 336, "c6", np.nan, pd.Timestamp("2003-06-01")], - [2, 3, 37, 337, "c7", np.nan, pd.NaT], - [3, 2, 38, 338, "c8", 3.8, pd.NaT], - [3, 2, 39, 339, "c9", 3.9, pd.NaT], - ], - columns=[ - "ids", - "dates", - "pn_nullint", - "pn_int", - "pn_str", - "pn_float", - "pn_date", - ], + [[0, 1, pd.NA], [1, 5, 32], [2, 1, 33],], + columns=["ids", "dates", "pn_nullint"], ) out["pn_nullint"] = out["pn_nullint"].astype("Int32") - out["pn_str"] = out["pn_str"].astype("string") return out diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 3507ac590..e5b142ce1 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -192,350 +192,141 @@ def test_join_lricomplex(c): # Correct answer dfcorrpn = pd.DataFrame( [ - [ - 0, - 1, - pd.NA, - 331, - "c1", - 3.1, - pd.Timestamp("2003-01-01"), - 0, - 2, - pd.NA, - 110, - "a1", - 1.1, - pd.Timestamp("2001-01-01"), - ], - [ - 0, - 2, - pd.NA, - 332, - "c2", - 3.2, - pd.Timestamp("2003-02-01"), - 0, - 2, - pd.NA, - 110, - "a1", - 1.1, - pd.Timestamp("2001-01-01"), - ], - [ - 0, - 3, - pd.NA, - 333, - "c3", - 3.3, - pd.Timestamp("2003-03-01"), - pd.NA, - pd.NA, - pd.NA, - pd.NA, - np.nan, - np.nan, - pd.NaT, - ], - [ - 1, - 3, - pd.NA, - 334, - "c4", - np.nan, - pd.Timestamp("2003-04-01"), - 2, - 5, - pd.NA, - 112, - "a3", - np.nan, - pd.Timestamp("2001-03-01"), - ], - [ - 1, - 4, - 35, - 335, - "c5", - np.nan, - pd.Timestamp("2003-05-01"), - 2, - 5, - pd.NA, - 112, - "a3", - np.nan, - pd.Timestamp("2001-03-01"), - ], - [ - 1, - 4, - 35, - 335, - "c5", - np.nan, - pd.Timestamp("2003-05-01"), - 4, - 6, - 13, - 113, - "a4", - np.nan, - pd.Timestamp("2001-04-01"), - ], - [ - 2, - 1, - 36, - 336, - "c6", - np.nan, - pd.Timestamp("2003-06-01"), - pd.NA, - pd.NA, - pd.NA, - pd.NA, - np.nan, - np.nan, - pd.NaT, - ], - [ - 2, - 3, - 37, - 337, - "c7", - np.nan, - pd.NaT, - pd.NA, - pd.NA, - pd.NA, - pd.NA, - np.nan, - np.nan, - pd.NaT, - ], - [3, 2, 38, 338, "c8", 3.8, pd.NaT, 1, 2, 14, 114, "a5", np.nan, pd.NaT], - [3, 2, 39, 339, "c9", 3.9, pd.NaT, 1, 2, 14, 114, "a5", np.nan, pd.NaT], - [3, 2, 38, 338, "c8", 3.8, pd.NaT, 2, 3, 15, 115, "a6", 1.6, pd.NaT], - [3, 2, 39, 339, "c9", 3.9, pd.NaT, 2, 3, 15, 115, "a6", 1.6, pd.NaT], - ], - columns=[ - "ids", - "dates", - "pn_nullint", - "pn_int", - "pn_str", - "pn_float", - "pn_date", - "startdate", - "enddate", - "lk_nullint", - "lk_int", - "lk_str", - "lk_float", - "lk_date", + [0, 1, pd.NA, pd.NA, pd.NA, pd.NA], + [1, 5, 32, 2, pd.NA, 112], + [1, 5, 32, 4, 13, 113], + [2, 1, 33, pd.NA, pd.NA, pd.NA], ], + columns=["ids", "dates", "pn_nullint", "startdate", "lk_nullint", "lk_int",], ) change_types = { "pn_nullint": "Int32", "lk_nullint": "Int32", "startdate": "Int64", - "enddate": "Int64", "lk_int": "Int64", - "pn_str": "string", - "lk_str": "string", } for k, v in change_types.items(): dfcorrpn[k] = dfcorrpn[k].astype(v) # Left Join querypnl = """ - select a.*, b.startdate, b.enddate, b.lk_nullint, 
b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_pn a left join user_table_lk b - on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + on a.ids=b.id and b.startdate<=a.dates """ dftestpnl = ( - c.sql(querypnl).compute().sort_values(["ids", "dates", "startdate", "enddate"]) - ) - assert_frame_equal( - dftestpnl.reset_index(drop=True), dfcorrpn.reset_index(drop=True) + c.sql(querypnl) + .compute() + .sort_values(["ids", "dates", "startdate"]) + .reset_index(drop=True) ) + assert_frame_equal(dftestpnl, dfcorrpn, check_dtype=False) # Right Join querypnr = """ - select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, - a.lk_float, a.lk_date + select b.*, a.startdate, a.lk_nullint, a.lk_int from user_table_lk a right join user_table_pn b - on b.ids=a.id and a.startdate<=b.dates and b.dates<=a.enddate + on b.ids=a.id and a.startdate<=b.dates """ dftestpnr = ( - c.sql(querypnr).compute().sort_values(["ids", "dates", "startdate", "enddate"]) - ) - assert_frame_equal( - dftestpnr.reset_index(drop=True), dfcorrpn.reset_index(drop=True) + c.sql(querypnr) + .compute() + .sort_values(["ids", "dates", "startdate"]) + .reset_index(drop=True) ) + assert_frame_equal(dftestpnr, dfcorrpn, check_dtype=False) # Inner Join querypni = """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_pn a inner join user_table_lk b - on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + on a.ids=b.id and b.startdate<=a.dates """ dftestpni = ( - c.sql(querypni).compute().sort_values(["ids", "dates", "startdate", "enddate"]) + c.sql(querypni) + .compute() + .sort_values(["ids", "dates", "startdate"]) + .reset_index(drop=True) ) assert_frame_equal( - dftestpni.reset_index(drop=True), + dftestpni, dfcorrpn.dropna(subset=["startdate"]) .assign( startdate=lambda x: x["startdate"].astype("int64"), - enddate=lambda x: x["enddate"].astype("int64"), lk_int=lambda x: x["lk_int"].astype("int64"), ) .reset_index(drop=True), + check_dtype=False, ) # ---------- Time-series data (inequality condition only) - # Correct answer + # # Correct answer dfcorrts = pd.DataFrame( [ - [ - 3, - pd.NA, - 221, - "b1", - 2.1, - pd.Timestamp("2002-01-01"), - 2, - 5, - pd.NA, - 112, - "a3", - np.nan, - pd.Timestamp("2001-03-01"), - ], - [ - 4, - 22, - 222, - "b2", - np.nan, - pd.Timestamp("2002-02-01"), - 2, - 5, - pd.NA, - 112, - "a3", - np.nan, - pd.Timestamp("2001-03-01"), - ], - [ - 4, - 22, - 222, - "b2", - np.nan, - pd.Timestamp("2002-02-01"), - 4, - 6, - 13, - 113, - "a4", - np.nan, - pd.Timestamp("2001-04-01"), - ], - [ - 7, - 23, - 223, - "b3", - 2.3, - pd.NaT, - pd.NA, - pd.NA, - pd.NA, - pd.NA, - np.nan, - np.nan, - pd.NaT, - ], - ], - columns=[ - "dates", - "ts_nullint", - "ts_int", - "ts_str", - "ts_float", - "ts_date", - "startdate", - "enddate", - "lk_nullint", - "lk_int", - "lk_str", - "lk_float", - "lk_date", + [1, 21, pd.NA, pd.NA, pd.NA], + [3, pd.NA, 2, pd.NA, 112], + [7, 23, 2, pd.NA, 112], + [7, 23, 4, 13, 113], ], + columns=["dates", "ts_nullint", "startdate", "lk_nullint", "lk_int",], ) change_types = { "ts_nullint": "Int32", "lk_nullint": "Int32", "startdate": "Int64", - "enddate": "Int64", "lk_int": "Int64", - "lk_str": "string", - "ts_str": "string", } for k, v in change_types.items(): dfcorrts[k] = dfcorrts[k].astype(v) # Left Join querytsl = """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - 
b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_ts a left join user_table_lk2 b - on b.startdate<=a.dates and a.dates<=b.enddate + on b.startdate<=a.dates """ - dftesttsl = c.sql(querytsl).compute().sort_values(["dates", "startdate", "enddate"]) - assert_frame_equal( - dftesttsl.reset_index(drop=True), dfcorrts.reset_index(drop=True) + dftesttsl = ( + c.sql(querytsl) + .compute() + .sort_values(["dates", "startdate"]) + .reset_index(drop=True) ) + assert_frame_equal(dftesttsl, dfcorrts, check_dtype=False) # Right Join querytsr = """ - select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, - a.lk_float, a.lk_date + select b.*, a.startdate, a.lk_nullint, a.lk_int from user_table_lk2 a right join user_table_ts b - on a.startdate<=b.dates and b.dates<=a.enddate + on a.startdate<=b.dates """ - dftesttsr = c.sql(querytsr).compute().sort_values(["dates", "startdate", "enddate"]) - assert_frame_equal( - dftesttsr.reset_index(drop=True), dfcorrts.reset_index(drop=True) + dftesttsr = ( + c.sql(querytsr) + .compute() + .sort_values(["dates", "startdate"]) + .reset_index(drop=True) ) + assert_frame_equal(dftesttsr, dfcorrts, check_dtype=False) # Inner Join querytsi = """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_ts a inner join user_table_lk2 b - on b.startdate<=a.dates and a.dates<=b.enddate + on b.startdate<=a.dates """ - dftesttsi = c.sql(querytsi).compute().sort_values(["dates", "startdate", "enddate"]) + dftesttsi = ( + c.sql(querytsi) + .compute() + .sort_values(["dates", "startdate"]) + .reset_index(drop=True) + ) assert_frame_equal( - dftesttsi.reset_index(drop=True), + dftesttsi, dfcorrts.dropna(subset=["startdate"]) .assign( startdate=lambda x: x["startdate"].astype("int64"), - enddate=lambda x: x["enddate"].astype("int64"), lk_int=lambda x: x["lk_int"].astype("int64"), ) .reset_index(drop=True), + check_dtype=False, ) diff --git a/tests/integration/test_postgres.py b/tests/integration/test_postgres.py index fc5f0924d..cd3689c09 100644 --- a/tests/integration/test_postgres.py +++ b/tests/integration/test_postgres.py @@ -10,13 +10,14 @@ def engine(): client = docker.from_env() network = client.networks.create("dask-sql", driver="bridge") + # For local test, you may need to add ports={"5432/tcp": "5432"} to expose port postgres = client.containers.run( "postgres:latest", detach=True, remove=True, network="dask-sql", environment={"POSTGRES_HOST_AUTH_METHOD": "trust"}, - ports={"5432/tcp": "5432"}, + # ports={"5432/tcp": "5432"}, ) try: @@ -33,7 +34,8 @@ def engine(): # get the address and create the connection postgres.reload() address = postgres.attrs["NetworkSettings"]["Networks"]["dask-sql"]["IPAddress"] - address = "localhost" + # For local test, you may need to assign address = "localhost" + # address = "localhost" port = 5432 engine = sqlalchemy.create_engine( @@ -141,36 +143,33 @@ def test_join_lricomplex( # Left Join assert_query_gives_same_result( """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_pn a left join user_table_lk b - on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + on a.ids=b.id and b.startdate<=a.dates """, - ["ids", "dates", "startdate", "enddate"], + ["ids", "dates", "startdate"], force_dtype="dask", check_dtype=True, ) # Right Join assert_query_gives_same_result( """ - 
select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, - a.lk_float, a.lk_date + select b.*, a.startdate, a.lk_nullint, a.lk_int from user_table_lk a right join user_table_pn b - on b.ids=a.id and a.startdate<=b.dates and b.dates<=a.enddate + on b.ids=a.id and a.startdate<=b.dates """, - ["ids", "dates", "startdate", "enddate"], + ["ids", "dates", "startdate"], force_dtype="dask", check_dtype=True, ) # Inner Join assert_query_gives_same_result( """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_pn a inner join user_table_lk b - on a.ids=b.id and b.startdate<=a.dates and a.dates<=b.enddate + on a.ids=b.id and b.startdate<=a.dates """, - ["ids", "dates", "startdate", "enddate"], + ["ids", "dates", "startdate"], force_dtype="dask", check_dtype=True, ) @@ -179,36 +178,33 @@ def test_join_lricomplex( # Left Join assert_query_gives_same_result( """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_ts a left join user_table_lk2 b - on b.startdate<=a.dates and a.dates<=b.enddate + on b.startdate<=a.dates """, - ["dates", "startdate", "enddate"], + ["dates", "startdate"], force_dtype="dask", check_dtype=True, ) # Right Join assert_query_gives_same_result( """ - select b.*, a.startdate, a.enddate, a.lk_nullint, a.lk_int, a.lk_str, - a.lk_float, a.lk_date + select b.*, a.startdate, a.lk_nullint, a.lk_int from user_table_lk2 a right join user_table_ts b - on a.startdate<=b.dates and b.dates<=a.enddate + on a.startdate<=b.dates """, - ["dates", "startdate", "enddate"], + ["dates", "startdate"], force_dtype="dask", check_dtype=True, ) # Inner Join assert_query_gives_same_result( """ - select a.*, b.startdate, b.enddate, b.lk_nullint, b.lk_int, b.lk_str, - b.lk_float, b.lk_date + select a.*, b.startdate, b.lk_nullint, b.lk_int from user_table_ts a inner join user_table_lk2 b - on b.startdate<=a.dates and a.dates<=b.enddate + on b.startdate<=a.dates """, - ["dates", "startdate", "enddate"], + ["dates", "startdate"], force_dtype="dask", check_dtype=True, )
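
For reference, the idea behind PATCH 7/8 can be shown with a minimal, self-contained pandas sketch (this is illustrative only, not the dask-sql code itself; the frame and column names lhs, rhs, and val are made up for the example): keep the plain equality merge, apply the inequality part of the ON clause as a filter, then re-merge against the original left (or right) table so that a LEFT or RIGHT JOIN with an extra range condition still keeps every outer row.

    import pandas as pd

    lhs = pd.DataFrame({"id": [0, 1, 2], "dates": [1, 5, 1]})
    rhs = pd.DataFrame(
        {"id": [0, 1, 1], "startdate": [5, 2, 4], "val": [111, 112, 113]}
    )

    # 1. Equality part of the condition: an ordinary left merge on "id".
    merged = lhs.merge(rhs, on="id", how="left")

    # 2. Inequality part of the condition: startdate <= dates.
    #    Left rows without an equality match carry NaN in "startdate", so the
    #    comparison is False and they are temporarily dropped here too.
    filtered = merged[merged["startdate"] <= merged["dates"]]

    # 3. Recover the lost left rows, mirroring the added line
    #    df.merge(df_lhs_renamed, on=list(df_lhs_renamed.columns), how="right").
    result = filtered.merge(lhs, on=list(lhs.columns), how="right")
    print(result)

As the caveat added to join.py notes, the recovered rows introduce NaN into the columns coming from the other table, so integer columns may be cast to float; this is why the reworked tests in PATCH 8/8 compare with check_dtype=False while still declaring nullable Int32/Int64 dtypes in the expected frames.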