Skip to content

Commit

Permalink
fix pandas segfault on array and bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaoying committed May 9, 2024
1 parent 4fd05d5 commit 1acc70a
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 147 deletions.
9 changes: 0 additions & 9 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ seed-db-more:
mysql --protocol tcp -h$MARIADB_HOST -P$MARIADB_PORT -u$MARIADB_USER -p$MARIADB_PASSWORD $MARIADB_DB < scripts/mysql.sql
trino $TRINO_URL --catalog=$TRINO_CATALOG < scripts/trino.sql

aaa2 conn="POSTGRES_URL":
cd connectorx-python && PYO3_PYTHON=$HOME/.pyenv/versions/3.8.12/bin/python3.8 PYTHONPATH=$HOME/.pyenv/versions/conn/lib/python3.8/site-packages LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.8.12/lib/ cargo run --no-default-features --features aaa --example test_aaa

aaa conn="POSTGRES_URL":
cd connectorx-python && PYO3_PYTHON=$HOME/.pyenv/versions/3.12.2/bin/python3.12 PYTHONPATH=$HOME/.pyenv/versions/conn12/lib/python3.12/site-packages LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.12.2/lib/ cargo run --no-default-features --features aaa --example test_aaa

aaa-d conn="POSTGRES_URL":
cd connectorx-python && RUST_BACKTRACE=1 PYO3_PYTHON=$HOME/.pyenv/versions/3.12.2/bin/python3.12 PYTHONPATH=$HOME/.pyenv/versions/conn12/lib/python3.12/site-packages LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.12.2/lib/ rust-lldb target/debug/examples/test_aaa

# benches
flame-tpch conn="POSTGRES_URL":
cd connectorx-python && PYO3_PYTHON=$HOME/.pyenv/versions/3.8.6/bin/python3.8 PYTHONPATH=$HOME/.pyenv/versions/conn/lib/python3.8/site-packages LD_LIBRARY_PATH=$HOME/.pyenv/versions/3.8.6/lib/ cargo run --no-default-features --features executable --features fptr --features nbstr --features dsts --features srcs --release --example flame_tpch {{conn}}
Expand Down
14 changes: 0 additions & 14 deletions connectorx-python/connectorx/tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,6 @@ def test_postgres_on_non_select(postgres_url: str) -> None:
query = "CREATE TABLE non_select(id INTEGER NOT NULL)"
df = read_sql(postgres_url, query)

def test_postgres_aaa(postgres_url: str) -> None:
# query = "SELECT test_int, test_str FROM test_table"
# # query = "SELECT test_bytea FROM test_types"
# # query = "SELECT test_boolarray FROM test_types"
# df = read_sql(postgres_url, query, partition_on="test_int", partition_num=2)

queries = [
"SELECT test_str FROM test_table WHERE test_int < 3",
"SELECT test_str FROM test_table WHERE test_int >= 3",
]

df = read_sql(postgres_url, query=queries)
print(df)

def test_postgres_aggregation(postgres_url: str) -> None:
query = "SELECT test_bool, SUM(test_float) FROM test_table GROUP BY test_bool"
df = read_sql(postgres_url, query)
Expand Down
31 changes: 0 additions & 31 deletions connectorx-python/examples/test_aaa.rs

This file was deleted.

4 changes: 2 additions & 2 deletions connectorx-python/src/pandas/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ where
}

/// Only fetch the metadata (header) of the destination.
pub fn get_meta(&mut self) -> Result<(), TP::Error> {
pub fn get_meta(mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TP::Error> {
let dorder = coordinate(S::DATA_ORDERS, PandasDestination::DATA_ORDERS)?;
self.src.set_data_order(dorder)?;
self.src.set_queries(self.queries.as_slice());
Expand All @@ -193,6 +193,6 @@ where
.collect::<CXResult<Vec<_>>>()?;
let names = self.src.names();
self.dst.allocate(0, &names, &dst_schema, dorder)?;
Ok(())
Ok(self.dst.result(py).unwrap())
}
}
141 changes: 58 additions & 83 deletions connectorx-python/src/pandas/get_meta.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
destination::PandasDestination,
dispatcher::PandasDispatcher,
transports::{
BigQueryPandasTransport, MsSQLPandasTransport, MysqlPandasTransport, OraclePandasTransport,
PostgresPandasTransport, SqlitePandasTransport, TrinoPandasTransport,
Expand Down Expand Up @@ -38,7 +39,7 @@ pub fn get_meta<'py>(
query: String,
) -> Bound<'py, PyAny> {
let source_conn = SourceConn::try_from(conn)?;
let mut destination = PandasDestination::new();
let destination = PandasDestination::new();
let queries = &[CXQuery::Naked(query)];

match source_conn.ty {
Expand All @@ -49,104 +50,81 @@ pub fn get_meta<'py>(
("csv", Some(tls_conn)) => {
let sb =
PostgresSource::<CSVProtocol, MakeTlsConnector>::new(config, tls_conn, 1)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<CSVProtocol, MakeTlsConnector>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("csv", None) => {
let sb = PostgresSource::<CSVProtocol, NoTls>::new(config, NoTls, 1)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<CSVProtocol, NoTls>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("binary", Some(tls_conn)) => {
let sb = PostgresSource::<PgBinaryProtocol, MakeTlsConnector>::new(
config, tls_conn, 1,
)?;
let mut dispatcher =
Dispatcher::<
_,
_,
PostgresPandasTransport<PgBinaryProtocol, MakeTlsConnector>,
>::new(sb, &mut destination, queries, None);
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<PgBinaryProtocol, MakeTlsConnector>,
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("binary", None) => {
let sb = PostgresSource::<PgBinaryProtocol, NoTls>::new(config, NoTls, 1)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<PgBinaryProtocol, NoTls>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("cursor", Some(tls_conn)) => {
let sb = PostgresSource::<CursorProtocol, MakeTlsConnector>::new(
config, tls_conn, 1,
)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<CursorProtocol, MakeTlsConnector>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("cursor", None) => {
let sb = PostgresSource::<CursorProtocol, NoTls>::new(config, NoTls, 1)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<CursorProtocol, NoTls>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("simple", Some(tls_conn)) => {
let sb = PostgresSource::<SimpleProtocol, MakeTlsConnector>::new(
config, tls_conn, 1,
)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<SimpleProtocol, MakeTlsConnector>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
("simple", None) => {
let sb = PostgresSource::<SimpleProtocol, NoTls>::new(config, NoTls, 1)?;
let mut dispatcher = Dispatcher::<
_,
let dispatcher = PandasDispatcher::<
_,
PostgresPandasTransport<SimpleProtocol, NoTls>,
>::new(
sb, &mut destination, queries, None
);
>::new(sb, destination, queries, None);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
_ => unimplemented!("{} protocol not supported", protocol),
}
Expand All @@ -155,93 +133,90 @@ pub fn get_meta<'py>(
// remove the first "sqlite://" manually since url.path is not correct for windows
let path = &source_conn.conn.as_str()[9..];
let source = SQLiteSource::new(path, 1)?;
let mut dispatcher = Dispatcher::<_, _, SqlitePandasTransport>::new(
let dispatcher = PandasDispatcher::<_, SqlitePandasTransport>::new(
source,
&mut destination,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
SourceType::MySQL => {
debug!("Protocol: {}", protocol);
match protocol {
"binary" => {
let source = MySQLSource::<MySQLBinaryProtocol>::new(&source_conn.conn[..], 1)?;
let mut dispatcher = Dispatcher::<
_,
_,
MysqlPandasTransport<MySQLBinaryProtocol>,
>::new(
source, &mut destination, queries, None
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
}
"text" => {
let source = MySQLSource::<TextProtocol>::new(&source_conn.conn[..], 1)?;
let mut dispatcher =
Dispatcher::<_, _, MysqlPandasTransport<TextProtocol>>::new(
let dispatcher =
PandasDispatcher::<_, MysqlPandasTransport<MySQLBinaryProtocol>>::new(
source,
&mut destination,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
"text" => {
let source = MySQLSource::<TextProtocol>::new(&source_conn.conn[..], 1)?;
let dispatcher = PandasDispatcher::<_, MysqlPandasTransport<TextProtocol>>::new(
source,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta(py)?
}
_ => unimplemented!("{} protocol not supported", protocol),
}
}
SourceType::MsSQL => {
let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
let source = MsSQLSource::new(rt, &source_conn.conn[..], 1)?;
let mut dispatcher = Dispatcher::<_, _, MsSQLPandasTransport>::new(
let dispatcher = PandasDispatcher::<_, MsSQLPandasTransport>::new(
source,
&mut destination,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
SourceType::Oracle => {
let source = OracleSource::new(&source_conn.conn[..], 1)?;
let mut dispatcher = Dispatcher::<_, _, OraclePandasTransport>::new(
let dispatcher = PandasDispatcher::<_, OraclePandasTransport>::new(
source,
&mut destination,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
SourceType::BigQuery => {
let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
let source = BigQuerySource::new(rt, &source_conn.conn[..])?;
let mut dispatcher = Dispatcher::<_, _, BigQueryPandasTransport>::new(
let dispatcher = PandasDispatcher::<_, BigQueryPandasTransport>::new(
source,
&mut destination,
destination,
queries,
None,
);
debug!("Running dispatcher");
dispatcher.get_meta()?;
dispatcher.get_meta(py)?
}
SourceType::Trino => {
let rt = Arc::new(tokio::runtime::Runtime::new().expect("Failed to create runtime"));
let source = TrinoSource::new(rt, &source_conn.conn[..])?;
let dispatcher = Dispatcher::<_, _, TrinoPandasTransport>::new(
let dispatcher = PandasDispatcher::<_, TrinoPandasTransport>::new(
source,
&mut destination,
destination,
queries,
None,
);
dispatcher.run()?;
dispatcher.get_meta(py)?
}
_ => unimplemented!("{:?} not implemented!", source_conn.ty),
}

destination.result(py)?
}
7 changes: 3 additions & 4 deletions connectorx-python/src/pandas/pandas_columns/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@ where
let nvecs = self.lengths.len();

if nvecs > 0 {
let py = unsafe { Python::assume_gil_acquired() };

{
Python::with_gil(|py| -> Result<(), ConnectorXPythonError> {
// allocation in python is not thread safe
let _guard = GIL_MUTEX
.lock()
Expand All @@ -249,7 +247,8 @@ where
}
}
}
}
Ok(())
})?;

self.buffer.truncate(0);
self.lengths.truncate(0);
Expand Down
7 changes: 3 additions & 4 deletions connectorx-python/src/pandas/pandas_columns/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ impl BytesColumn {
let nstrings = self.bytes_lengths.len();

if nstrings > 0 {
let py = unsafe { Python::assume_gil_acquired() };

{
Python::with_gil(|py| -> Result<(), ConnectorXPythonError> {
// allocation in python is not thread safe
let _guard = GIL_MUTEX
.lock()
Expand All @@ -205,7 +203,8 @@ impl BytesColumn {
}
}
}
}
Ok(())
})?;

self.bytes_buf.truncate(0);
self.bytes_lengths.truncate(0);
Expand Down

0 comments on commit 1acc70a

Please sign in to comment.