diff --git a/Federation.md b/Federation.md
new file mode 100644
index 000000000..27c1039a8
--- /dev/null
+++ b/Federation.md
@@ -0,0 +1,2 @@
+# Enabling Query Federation
+
diff --git a/Justfile b/Justfile
index 1872bb498..9c0d05962 100644
--- a/Justfile
+++ b/Justfile
@@ -34,8 +34,8 @@ bootstrap-python:
     cd connectorx-python && poetry install
 
 setup-java:
-    cd federated-query/rewriter && mvn package -Dmaven.test.skip=true
-    cp -f ./federated-query/rewriter/target/federated-rewriter-1.0-SNAPSHOT-jar-with-dependencies.jar connectorx-python/connectorx/dependencies/federated-rewriter.jar
+    cd $ACCIO_PATH/rewriter && mvn package -Dmaven.test.skip=true
+    cp -f $ACCIO_PATH/rewriter/target/accio-rewriter-1.0-SNAPSHOT-jar-with-dependencies.jar connectorx-python/connectorx/dependencies/federated-rewriter.jar
 
 setup-python:
     cd connectorx-python && poetry run maturin develop --release
diff --git a/README.md b/README.md
index 3e4595906..03d081a21 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,17 @@ The function will partition the query by **evenly** splitting the specified colu
 ConnectorX will assign one thread for each partition to load and write data in parallel.
 Currently, we support partitioning on **numerical** columns (**cannot contain NULL**) for **SPJA** queries.
 
+**Experimental: We now provide federated query support. You can write a single query that joins tables from two or more databases!**
+```python
+import connectorx as cx
+db1 = "postgresql://username1:password1@server1:port1/database1"
+db2 = "postgresql://username2:password2@server2:port2/database2"
+cx.read_sql({"db1": db1, "db2": db2}, "SELECT * FROM db1.nation n, db2.region r where n.n_regionkey = r.r_regionkey")
+```
+By default, we push down all joins between tables from the same data source. More details on setup and configuration can be found [here](https://github.com/sfu-db/connector-x/blob/main/Federation.md).
+
+Check out more detailed usage and examples [here](https://sfu-db.github.io/connector-x/api.html). A general introduction of the project can be found in this [blog post](https://towardsdatascience.com/connectorx-the-fastest-way-to-load-data-from-databases-a65d4d4062d5).
+
 # Installation
 
 ```bash
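As a quick sketch of the new keyword this patch threads through the stack, the example below extends the README snippet with the `strategy` argument. The only strategy values visible in this diff are `"pushdown"` (the default applied on the Rust side) and `"benefit"` (exercised by the new `test_fed_spj_benefit` test further down); the connection strings are placeholders, as in the README.

```python
import connectorx as cx

# Placeholder connection strings, as in the README example.
db1 = "postgresql://username1:password1@server1:port1/database1"
db2 = "postgresql://username2:password2@server2:port2/database2"

query = "SELECT * FROM db1.nation n, db2.region r WHERE n.n_regionkey = r.r_regionkey"

# Omitting `strategy` keeps the previous behavior; the binding falls back to "pushdown".
df_pushdown = cx.read_sql({"db1": db1, "db2": db2}, query)

# "benefit" is the alternative rewrite strategy used by the new test in this patch.
df_benefit = cx.read_sql({"db1": db1, "db2": db2}, query, strategy="benefit")
```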
diff --git a/connectorx-cpp/src/lib.rs b/connectorx-cpp/src/lib.rs
index 1147a970d..2169f139f 100644
--- a/connectorx-cpp/src/lib.rs
+++ b/connectorx-cpp/src/lib.rs
@@ -71,6 +71,7 @@ pub unsafe extern "C" fn free_plans(res: *const CXSlice) {
 pub unsafe extern "C" fn connectorx_rewrite(
     conn_list: *const CXSlice,
     query: *const c_char,
+    strategy: *const c_char,
 ) -> CXSlice {
     let mut db_map = HashMap::new();
     let conn_slice = unsafe { std::slice::from_raw_parts((*conn_list).ptr, (*conn_list).len) };
@@ -117,16 +118,18 @@
     }
 
     let query_str = unsafe { CStr::from_ptr(query) }.to_str().unwrap();
+    let strategy_str = unsafe { CStr::from_ptr(strategy) }.to_str().unwrap();
     let j4rs_base = match env::var("CX_LIB_PATH") {
         Ok(val) => Some(val),
         Err(_) => None,
     };
     // println!("j4rs_base: {:?}", j4rs_base);
-    let fed_plan: Vec = rewrite_sql(query_str, &db_map, j4rs_base.as_deref())
-        .unwrap()
-        .into_iter()
-        .map(|p| p.into())
-        .collect();
+    let fed_plan: Vec =
+        rewrite_sql(query_str, &db_map, j4rs_base.as_deref(), strategy_str)
+            .unwrap()
+            .into_iter()
+            .map(|p| p.into())
+            .collect();
 
     CXSlice::<_>::new_from_vec(fed_plan)
 }
diff --git a/connectorx-python/connectorx/__init__.py b/connectorx-python/connectorx/__init__.py
index 1dd7178a5..b4240b1be 100644
--- a/connectorx-python/connectorx/__init__.py
+++ b/connectorx-python/connectorx/__init__.py
@@ -259,6 +259,7 @@ def read_sql(
     partition_range: tuple[int, int] | None = None,
     partition_num: int | None = None,
     index_col: str | None = None,
+    strategy: str | None = None,
 ) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
     """
     Run the SQL query, download the data from database into a dataframe.
@@ -282,6 +283,8 @@ def read_sql(
         how many partitions to generate.
     index_col
         the index column to set; only applicable for return type "pandas", "modin", "dask".
+    strategy
+        the strategy used to rewrite the federated query for join pushdown.
 
     Examples
     ========
@@ -318,7 +321,7 @@ def read_sql(
         query = remove_ending_semicolon(query)
-        result = _read_sql2(query, conn)
+        result = _read_sql2(query, conn, strategy)
         df = reconstruct_arrow(result)
         if return_type == "pandas":
             df = df.to_pandas(date_as_object=False, split_blocks=False)
diff --git a/connectorx-python/connectorx/dependencies/federated-rewriter.jar b/connectorx-python/connectorx/dependencies/federated-rewriter.jar
deleted file mode 100644
index 569f31a99..000000000
Binary files a/connectorx-python/connectorx/dependencies/federated-rewriter.jar and /dev/null differ
diff --git a/connectorx-python/connectorx/tests/test_federation.py b/connectorx-python/connectorx/tests/test_federation.py
index 36203d78c..08ba6d935 100644
--- a/connectorx-python/connectorx/tests/test_federation.py
+++ b/connectorx-python/connectorx/tests/test_federation.py
@@ -20,8 +20,8 @@ def db2_url() -> str:
 
 
 @pytest.mark.skipif(
-    not (os.environ.get("DB1") and os.environ.get("DB2")),
-    reason="Do not test federated queries is set unless both `DB1` and `DB2` are set",
+    not (os.environ.get("DB1") and os.environ.get("DB2") and os.environ.get("FED_CONFIG_PATH")),
+    reason="Do not test federated queries unless `FED_CONFIG_PATH`, `DB1`, and `DB2` are all set",
 )
 def test_fed_spj(db1_url: str, db2_url: str) -> None:
     query = "SELECT T.test_int, T.test_bool, S.test_language FROM db1.test_table T INNER JOIN db2.test_str S ON T.test_int = S.id"
@@ -41,8 +41,8 @@ def test_fed_spj(db1_url: str, db2_url: str) -> None:
 
 
 @pytest.mark.skipif(
-    not (os.environ.get("DB1") and os.environ.get("DB2")),
-    reason="Do not test federated queries is set unless both `DB1` and `DB2` are set",
+    not (os.environ.get("DB1") and os.environ.get("DB2") and os.environ.get("FED_CONFIG_PATH")),
+    reason="Do not test federated queries unless `FED_CONFIG_PATH`, `DB1`, and `DB2` are all set",
 )
 def test_fed_spja(db1_url: str, db2_url: str) -> None:
     query = "select test_bool, AVG(test_float) as avg_float, SUM(test_int) as sum_int from db1.test_table as a, db2.test_str as b where a.test_int = b.id AND test_nullint is not NULL GROUP BY test_bool ORDER BY sum_int"
@@ -57,3 +57,23 @@ def test_fed_spja(db1_url: str, db2_url: str) -> None:
     )
     df.sort_values(by="SUM_INT", inplace=True, ignore_index=True)
     assert_frame_equal(df, expected, check_names=True)
+
+@pytest.mark.skipif(
+    not (os.environ.get("DB1") and os.environ.get("DB2") and os.environ.get("FED_CONFIG_PATH")),
+    reason="Do not test federated queries unless `FED_CONFIG_PATH`, `DB1`, and `DB2` are all set",
+)
+def test_fed_spj_benefit(db1_url: str, db2_url: str) -> None:
+    query = "SELECT T.test_int, T.test_bool, S.test_language FROM db1.test_table T INNER JOIN db2.test_str S ON T.test_int = S.id"
+    df = read_sql({"db1": db1_url, "db2": db2_url}, query, strategy="benefit")
+    expected = pd.DataFrame(
+        index=range(5),
+        data={
+            "TEST_INT": pd.Series([0, 1, 2, 3, 4], dtype="int64"),
+            "TEST_BOOL": pd.Series([None, True, False, False, None], dtype="object"),
+            "TEST_LANGUAGE": pd.Series(
+                ["English", "中文", "日本語", "русский", "Emoji"], dtype="object"
+            ),
+        },
+    )
+    df.sort_values(by="TEST_INT", inplace=True, ignore_index=True)
+    assert_frame_equal(df, expected, check_names=True)
\ No newline at end of file
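With these skip conditions, the federated tests only run when `DB1`, `DB2`, and `FED_CONFIG_PATH` are all present in the environment. Below is a minimal sketch of driving them from Python; the connection strings and config path are hypothetical placeholders, not values from this patch.

```python
import os
import subprocess

# Hypothetical values: point DB1/DB2 at two reachable Postgres instances and
# FED_CONFIG_PATH at the directory holding the federation configuration.
env = dict(
    os.environ,
    DB1="postgresql://user1:pass1@localhost:5432/db1",
    DB2="postgresql://user2:pass2@localhost:5433/db2",
    FED_CONFIG_PATH="/path/to/fed-config",
)

# The federated tests above are skipped unless all three variables are set.
subprocess.run(
    ["pytest", "connectorx-python/connectorx/tests/test_federation.py", "-v"],
    env=env,
    check=True,
)
```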
diff --git a/connectorx-python/src/lib.rs b/connectorx-python/src/lib.rs
index 90a9356c3..4f27cc30d 100644
--- a/connectorx-python/src/lib.rs
+++ b/connectorx-python/src/lib.rs
@@ -68,6 +68,7 @@ pub fn read_sql2<'py>(
     py: Python<'py>,
     sql: &str,
     db_map: HashMap,
+    strategy: Option<&str>,
 ) -> PyResult> {
     let rbs = run(
         sql.to_string(),
@@ -77,6 +78,7 @@ pub fn read_sql2<'py>(
                 .unwrap_or(J4RS_BASE_PATH.to_string())
                 .as_str(),
         ),
+        strategy.unwrap_or("pushdown"),
     )
     .map_err(|e| PyRuntimeError::new_err(format!("{}", e)))?;
     let ptrs = arrow::to_ptrs(rbs);
diff --git a/connectorx/src/fed_dispatcher.rs b/connectorx/src/fed_dispatcher.rs
index 875a9fa29..43a184953 100644
--- a/connectorx/src/fed_dispatcher.rs
+++ b/connectorx/src/fed_dispatcher.rs
@@ -14,6 +14,7 @@ pub fn run(
     sql: String,
     db_map: HashMap,
     j4rs_base: Option<&str>,
+    strategy: &str,
 ) -> Vec {
     debug!("federated input sql: {}", sql);
     let mut db_conn_map: HashMap = HashMap::new();
@@ -28,7 +29,7 @@
             ),
         );
     }
-    let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base)?;
+    let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base, strategy)?;
 
     debug!("fetch queries from remote");
     let (sender, receiver) = channel();
diff --git a/connectorx/src/fed_rewriter.rs b/connectorx/src/fed_rewriter.rs
index 9d717c3bd..e1ce123d7 100644
--- a/connectorx/src/fed_rewriter.rs
+++ b/connectorx/src/fed_rewriter.rs
@@ -83,6 +83,7 @@ fn create_sources(
     jvm: &Jvm,
     db_map: &HashMap,
 ) -> (Instance, Instance) {
+    debug!("Could not find environment variable `FED_CONFIG_PATH`, using manual configuration (c++ API only)!");
     let mut db_config = vec![];
     let db_manual = jvm.create_instance("java.util.HashMap", &[])?;
@@ -126,12 +127,17 @@
 
 #[allow(dead_code)]
 #[throws(ConnectorXOutError)]
-fn create_sources2(jvm: &Jvm, db_map: &HashMap) -> Instance {
+fn create_sources2(
+    jvm: &Jvm,
+    db_map: &HashMap,
+) -> (Instance, Instance) {
+    debug!("Found environment variable `FED_CONFIG_PATH`, using configurations!");
     let mut dbs = vec![];
+    let db_manual = jvm.create_instance("java.util.HashMap", &[])?;
     for db in db_map.keys() {
         dbs.push(String::from(db));
     }
-    jvm.java_list("java.lang.String", dbs)?
+    (jvm.java_list("java.lang.String", dbs)?, db_manual)
 }
@@ -139,16 +145,22 @@ pub fn rewrite_sql(
     sql: &str,
     db_map: &HashMap,
     j4rs_base: Option<&str>,
+    strategy: &str,
 ) -> Vec {
     let jvm = init_jvm(j4rs_base)?;
     debug!("init jvm successfully!");
 
     let sql = InvocationArg::try_from(sql).unwrap();
-    let (db_config, db_manual) = create_sources(&jvm, db_map)?;
-    let rewriter = jvm.create_instance("ai.dataprep.federated.FederatedQueryRewriter", &[])?;
+    let strategy = InvocationArg::try_from(strategy).unwrap();
+
+    let (db_config, db_manual) = match env::var("FED_CONFIG_PATH") {
+        Ok(_) => create_sources2(&jvm, db_map)?,
+        _ => create_sources(&jvm, db_map)?,
+    };
+    let rewriter = jvm.create_instance("ai.dataprep.accio.FederatedQueryRewriter", &[])?;
 
     let db_config = InvocationArg::try_from(db_config).unwrap();
     let db_manual = InvocationArg::try_from(db_manual).unwrap();
-    let plan = jvm.invoke(&rewriter, "rewrite3", &[sql, db_config, db_manual])?;
+    let plan = jvm.invoke(&rewriter, "rewrite", &[sql, db_config, db_manual, strategy])?;
 
     let count = jvm.invoke(&plan, "getCount", &[])?;
     let count: i32 = jvm.to_rust(count)?;