Skip to content

Commit

Permalink
update federation related
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaoying committed Oct 8, 2024
1 parent 11005c5 commit 8c53c58
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 18 deletions.
2 changes: 2 additions & 0 deletions Federation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Enabling Query Federation

4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ bootstrap-python:
cd connectorx-python && poetry install

setup-java:
cd federated-query/rewriter && mvn package -Dmaven.test.skip=true
cp -f ./federated-query/rewriter/target/federated-rewriter-1.0-SNAPSHOT-jar-with-dependencies.jar connectorx-python/connectorx/dependencies/federated-rewriter.jar
cd $ACCIO_PATH/rewriter && mvn package -Dmaven.test.skip=true
cp -f $ACCIO_PATH/rewriter/target/accio-rewriter-1.0-SNAPSHOT-jar-with-dependencies.jar connectorx-python/connectorx/dependencies/federated-rewriter.jar

setup-python:
cd connectorx-python && poetry run maturin develop --release
Expand Down
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ The function will partition the query by **evenly** splitting the specified colu
ConnectorX will assign one thread for each partition to load and write data in parallel.
Currently, we support partitioning on **numerical** columns (**cannot contain NULL**) for **SPJA** queries.

**Experimental: We now provide federated query support, so you can write a single query to join tables from two or more databases!**
```python
import connectorx as cx
db1 = "postgresql://username1:password1@server1:port1/database1"
db2 = "postgresql://username2:password2@server2:port2/database2"
cx.read_sql({"db1": db1, "db2": db2}, "SELECT * FROM db1.nation n, db2.region r where n.n_regionkey = r.r_regionkey")
```
By default, we push down all joins from the same data source. More details on setup and configuration can be found [here](https://github.com/sfu-db/connector-x/blob/main/Federation.md).

Check out more detailed usage and examples [here](https://sfu-db.github.io/connector-x/api.html). A general introduction of the project can be found in this [blog post](https://towardsdatascience.com/connectorx-the-fastest-way-to-load-data-from-databases-a65d4d4062d5).

# Installation

```bash
Expand Down
13 changes: 8 additions & 5 deletions connectorx-cpp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub unsafe extern "C" fn free_plans(res: *const CXSlice<CXFederatedPlan>) {
pub unsafe extern "C" fn connectorx_rewrite(
conn_list: *const CXSlice<CXConnectionInfo>,
query: *const c_char,
strategy: *const c_char,
) -> CXSlice<CXFederatedPlan> {
let mut db_map = HashMap::new();
let conn_slice = unsafe { std::slice::from_raw_parts((*conn_list).ptr, (*conn_list).len) };
Expand Down Expand Up @@ -117,16 +118,18 @@ pub unsafe extern "C" fn connectorx_rewrite(
}

let query_str = unsafe { CStr::from_ptr(query) }.to_str().unwrap();
let strategy_str = unsafe { CStr::from_ptr(strategy) }.to_str().unwrap();
let j4rs_base = match env::var("CX_LIB_PATH") {
Ok(val) => Some(val),
Err(_) => None,
};
// println!("j4rs_base: {:?}", j4rs_base);
let fed_plan: Vec<CXFederatedPlan> = rewrite_sql(query_str, &db_map, j4rs_base.as_deref())
.unwrap()
.into_iter()
.map(|p| p.into())
.collect();
let fed_plan: Vec<CXFederatedPlan> =
rewrite_sql(query_str, &db_map, j4rs_base.as_deref(), strategy_str)
.unwrap()
.into_iter()
.map(|p| p.into())
.collect();

CXSlice::<_>::new_from_vec(fed_plan)
}
Expand Down
5 changes: 4 additions & 1 deletion connectorx-python/connectorx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def read_sql(
partition_range: tuple[int, int] | None = None,
partition_num: int | None = None,
index_col: str | None = None,
strategy: str | None = None,
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
"""
Run the SQL query, download the data from database into a dataframe.
Expand All @@ -282,6 +283,8 @@ def read_sql(
how many partitions to generate.
index_col
the index column to set; only applicable for return type "pandas", "modin", "dask".
strategy
strategy of rewriting the federated query for join pushdown
Examples
========
Expand Down Expand Up @@ -318,7 +321,7 @@ def read_sql(

query = remove_ending_semicolon(query)

result = _read_sql2(query, conn)
result = _read_sql2(query, conn, strategy)
df = reconstruct_arrow(result)
if return_type == "pandas":
df = df.to_pandas(date_as_object=False, split_blocks=False)
Expand Down
Binary file not shown.
28 changes: 24 additions & 4 deletions connectorx-python/connectorx/tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def db2_url() -> str:


@pytest.mark.skipif(
not (os.environ.get("DB1") and os.environ.get("DB2")),
reason="Do not test federated queries is set unless both `DB1` and `DB2` are set",
not (os.environ.get("DB1") and os.environ.get("DB2") and os.environ.get("FED_CONFIG_PATH")),
reason="Do not test federated queries is set unless both `FED_CONFIG_PATH`, `DB1` and `DB2` are set",
)
def test_fed_spj(db1_url: str, db2_url: str) -> None:
query = "SELECT T.test_int, T.test_bool, S.test_language FROM db1.test_table T INNER JOIN db2.test_str S ON T.test_int = S.id"
Expand All @@ -41,8 +41,8 @@ def test_fed_spj(db1_url: str, db2_url: str) -> None:


@pytest.mark.skipif(
not (os.environ.get("DB1") and os.environ.get("DB2")),
reason="Do not test federated queries is set unless both `DB1` and `DB2` are set",
not (os.environ.get("DB1") and os.environ.get("DB2") and os.environ.get("FED_CONFIG_PATH")),
reason="Do not test federated queries is set unless both `FED_CONFIG_PATH`, `DB1` and `DB2` are set",
)
def test_fed_spja(db1_url: str, db2_url: str) -> None:
query = "select test_bool, AVG(test_float) as avg_float, SUM(test_int) as sum_int from db1.test_table as a, db2.test_str as b where a.test_int = b.id AND test_nullint is not NULL GROUP BY test_bool ORDER BY sum_int"
Expand All @@ -57,3 +57,23 @@ def test_fed_spja(db1_url: str, db2_url: str) -> None:
)
df.sort_values(by="SUM_INT", inplace=True, ignore_index=True)
assert_frame_equal(df, expected, check_names=True)

@pytest.mark.skipif(
    not all(os.environ.get(var) for var in ("DB1", "DB2", "FED_CONFIG_PATH")),
    reason="Do not test federated queries is set unless both `FED_CONFIG_PATH`, `DB1` and `DB2` are set",
)
def test_fed_spj_benefit(db1_url: str, db2_url: str) -> None:
    """Federated select-project-join across two sources using the "benefit" strategy.

    Joins `db1.test_table` with `db2.test_str` on `test_int = id`, passing
    `strategy="benefit"` to `read_sql`, and checks the sorted result against
    the expected five-row frame.
    """
    sources = {"db1": db1_url, "db2": db2_url}
    sql = "SELECT T.test_int, T.test_bool, S.test_language FROM db1.test_table T INNER JOIN db2.test_str S ON T.test_int = S.id"
    actual = read_sql(sources, sql, strategy="benefit")

    expected_columns = {
        "TEST_INT": pd.Series([0, 1, 2, 3, 4], dtype="int64"),
        "TEST_BOOL": pd.Series([None, True, False, False, None], dtype="object"),
        "TEST_LANGUAGE": pd.Series(
            ["English", "中文", "日本語", "русский", "Emoji"], dtype="object"
        ),
    }
    expected = pd.DataFrame(index=range(5), data=expected_columns)

    # Row order from the federated join is not relied on; sort before comparing.
    actual.sort_values(by="TEST_INT", inplace=True, ignore_index=True)
    assert_frame_equal(actual, expected, check_names=True)
2 changes: 2 additions & 0 deletions connectorx-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub fn read_sql2<'py>(
py: Python<'py>,
sql: &str,
db_map: HashMap<String, String>,
strategy: Option<&str>,
) -> PyResult<Bound<'py, PyAny>> {
let rbs = run(
sql.to_string(),
Expand All @@ -77,6 +78,7 @@ pub fn read_sql2<'py>(
.unwrap_or(J4RS_BASE_PATH.to_string())
.as_str(),
),
strategy.unwrap_or("pushdown"),
)
.map_err(|e| PyRuntimeError::new_err(format!("{}", e)))?;
let ptrs = arrow::to_ptrs(rbs);
Expand Down
3 changes: 2 additions & 1 deletion connectorx/src/fed_dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub fn run(
sql: String,
db_map: HashMap<String, String>,
j4rs_base: Option<&str>,
strategy: &str,
) -> Vec<RecordBatch> {
debug!("federated input sql: {}", sql);
let mut db_conn_map: HashMap<String, FederatedDataSourceInfo> = HashMap::new();
Expand All @@ -28,7 +29,7 @@ pub fn run(
),
);
}
let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base)?;
let fed_plan = rewrite_sql(sql.as_str(), &db_conn_map, j4rs_base, strategy)?;

debug!("fetch queries from remote");
let (sender, receiver) = channel();
Expand Down
22 changes: 17 additions & 5 deletions connectorx/src/fed_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ fn create_sources(
jvm: &Jvm,
db_map: &HashMap<String, FederatedDataSourceInfo>,
) -> (Instance, Instance) {
debug!("Could not find environment variable `FED_CONFIG_PATH`, use manual configuration (c++ API only)!");
let mut db_config = vec![];
let db_manual = jvm.create_instance("java.util.HashMap", &[])?;

Expand Down Expand Up @@ -126,29 +127,40 @@ fn create_sources(

#[allow(dead_code)]
#[throws(ConnectorXOutError)]
fn create_sources2(jvm: &Jvm, db_map: &HashMap<String, FederatedDataSourceInfo>) -> Instance {
fn create_sources2(
jvm: &Jvm,
db_map: &HashMap<String, FederatedDataSourceInfo>,
) -> (Instance, Instance) {
debug!("Found environment variable `FED_CONFIG_PATH`, use configurations!");
let mut dbs = vec![];
let db_manual = jvm.create_instance("java.util.HashMap", &[])?;
for db in db_map.keys() {
dbs.push(String::from(db));
}
jvm.java_list("java.lang.String", dbs)?
(jvm.java_list("java.lang.String", dbs)?, db_manual)
}

#[throws(ConnectorXOutError)]
pub fn rewrite_sql(
sql: &str,
db_map: &HashMap<String, FederatedDataSourceInfo>,
j4rs_base: Option<&str>,
strategy: &str,
) -> Vec<Plan> {
let jvm = init_jvm(j4rs_base)?;
debug!("init jvm successfully!");

let sql = InvocationArg::try_from(sql).unwrap();
let (db_config, db_manual) = create_sources(&jvm, db_map)?;
let rewriter = jvm.create_instance("ai.dataprep.federated.FederatedQueryRewriter", &[])?;
let strategy = InvocationArg::try_from(strategy).unwrap();

let (db_config, db_manual) = match env::var("FED_CONFIG_PATH") {
Ok(_) => create_sources2(&jvm, db_map)?,
_ => create_sources(&jvm, db_map)?,
};
let rewriter = jvm.create_instance("ai.dataprep.accio.FederatedQueryRewriter", &[])?;
let db_config = InvocationArg::try_from(db_config).unwrap();
let db_manual = InvocationArg::try_from(db_manual).unwrap();
let plan = jvm.invoke(&rewriter, "rewrite3", &[sql, db_config, db_manual])?;
let plan = jvm.invoke(&rewriter, "rewrite", &[sql, db_config, db_manual, strategy])?;

let count = jvm.invoke(&plan, "getCount", &[])?;
let count: i32 = jvm.to_rust(count)?;
Expand Down

0 comments on commit 8c53c58

Please sign in to comment.