From 8b40f2824885144ee379d7ece8517503172b77b6 Mon Sep 17 00:00:00 2001 From: Pekka Enberg Date: Sat, 10 Aug 2024 15:17:36 +0300 Subject: [PATCH] Optimize executescript() to use batching Refs #70 --- src/lib.rs | 23 +++++++++++++---------- tests/test_suite.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d89930c..c8ae14b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -205,16 +205,9 @@ impl Connection { } fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> { - let statements = script.split(';'); - for statement in statements { - let statement = statement.trim(); - if !statement.is_empty() { - let cursor = Connection::cursor(&self_)?; - self_ - .rt - .block_on(async { execute(&cursor, statement.to_string(), None).await })?; - } - } + let _ = self_.rt.block_on(async { + self_.conn.execute_batch(&script).await + }).map_err(to_py_err); Ok(()) } @@ -272,6 +265,16 @@ impl Cursor { Ok(self_) } + fn executescript<'a>(self_: PyRef<'a, Self>, script: String) -> PyResult> { + self_ + .rt + .block_on(async { + self_.conn.execute_batch(&script).await + }) + .map_err(to_py_err)?; + Ok(self_) + } + #[getter] fn description(self_: PyRef<'_, Self>) -> PyResult> { let stmt = self_.stmt.borrow(); diff --git a/tests/test_suite.py b/tests/test_suite.py index 67f47b5..2b11d5c 100644 --- a/tests/test_suite.py +++ b/tests/test_suite.py @@ -88,6 +88,19 @@ def test_cursor_executemany(provider): res = cur.execute("SELECT * FROM users") assert [(1, 'alice@example.com'), (2, 'bob@example.com')] == res.fetchall() +@pytest.mark.parametrize("provider", ["libsql", "sqlite"]) +def test_cursor_executescript(provider): + conn = connect(provider, ":memory:") + cur = conn.cursor() + cur.executescript(""" + CREATE TABLE users (id INTEGER, email TEXT); + INSERT INTO users VALUES (1, 'alice@example.org'); + INSERT INTO users VALUES (2, 'bob@example.org'); + """) + res = cur.execute("SELECT * FROM users") + assert (1, 'alice@example.org') == res.fetchone() + assert (2, 'bob@example.org') == res.fetchone() + @pytest.mark.parametrize("provider", ["libsql", "sqlite"]) def test_lastrowid(provider): conn = connect(provider, ":memory:")