diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 2183482c71f..fa123d4b60b 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -247,6 +247,7 @@ cdef class DatabaseConnectionView: cdef describe_state(self) cdef encode_state(self) cdef decode_state(self, type_id, data) + cdef bint needs_commit_after_state_sync(self) cdef check_capabilities( self, diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index d7ec0df29ab..75eaf3c309b 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -941,6 +941,16 @@ cdef class DatabaseConnectionView: aliases, session_config, globals_, type_id, data ) + cdef bint needs_commit_after_state_sync(self): + return any( + tx_conf in self._config + for tx_conf in [ + "default_transaction_isolation", + "default_transaction_deferrable", + # default_transaction_access_mode is not yet a backend config + ] + ) + property txid: def __get__(self): return self._txid diff --git a/edb/server/pgcon/pgcon.pxd b/edb/server/pgcon/pgcon.pxd index 6b9cf16b0c1..61e5f65a65f 100644 --- a/edb/server/pgcon/pgcon.pxd +++ b/edb/server/pgcon/pgcon.pxd @@ -138,6 +138,7 @@ cdef class PGConnection: public object pinned_by object last_state + bint state_reset_needs_commit public object last_init_con_data str last_indirect_return @@ -165,6 +166,7 @@ cdef class PGConnection: object bind_datas, bytes state, ssize_t start, ssize_t end, int dbver, object parse_array, object query_prefix, + bint needs_commit_state, ) cdef _rewrite_copy_data( diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 43201ab7ed9..b48efedd1ee 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -261,7 +261,51 @@ cdef class PGConnection: # by the backend. self.aborted_with_error = None + # Session State Management + # ------------------------ + # Due to the fact that backend sessions are not pinned to frontend + # sessions (EdgeQL, SQL, etc.) out of transactions, we need to sync + # the backend state with the frontend state before executing queries. + # + # For performance reasons, we try to avoid syncing the state by + # remembering the last state we've synced (last_state), and prefer + # backend connection with the same state as the frontend. + # + # Syncing the state is done by resetting the session state as a whole, + # followed by applying the new state, so that we don't have to track + # individual config resets. Again for performance reasons, the state + # sync is usually applied in the same implicit transaction as the + # actual query in order to avoid extra round trips. + # + # Though, there are exceptions when we need to sync the state in a + # separate transaction by inserting a SYNC message before the actual + # query. This is because either that the query itself is a START + # TRANSACTION / non-transactional command and a few other cases (see + # _parse_execute() below), or the state change affects new transaction + # creation like changing the `default_transaction_isolation` or its + # siblings (see `needs_commit_state` parameters). In such cases, we + # remember the `last_state` immediately after we received the + # ReadyForQuery message caused by the SYNC above, if there are no + # errors happened during state sync. Otherwise, we only remember + # `last_state` after the implicit transaction ends successfully, when + # we're sure the state is synced permanently. + # + # The actual queries may also change the session state. Regardless of + # how we synced state previously, we always remember the `last_state` + # after successful executions (also after transactions without errors, + # implicit or explicit). + # + # Finally, resetting an existing session state that was positive in + # `needs_commit_state` also requires a commit, because the new state + # may not have `needs_commit_state`. To achieve this, we remember the + # previous `needs_commit_state` in `state_reset_needs_commit` and + # always insert a SYNC in the next state sync if it's True. Also, if + # the actual queries modified those `default_transaction_*` settings, + # we also need to set `state_reset_needs_commit` to True for the next + # state sync(reset). See `needs_commit_after_state_sync()` functions + # in dbview classes (EdgeQL and SQL). self.last_state = dbview.DEFAULT_STATE + self.state_reset_needs_commit = False cpdef set_stmt_cache_size(self, int maxsize): self.prep_stmts.resize(maxsize) @@ -590,6 +634,7 @@ cdef class PGConnection: object bind_datas, bytes state, ssize_t start, ssize_t end, int dbver, object parse_array, object query_prefix, + bint needs_commit_state, ): # parse_array is an array of booleans for output with the same size as # the query_unit_group, indicating if each unit is freshly parsed @@ -607,6 +652,8 @@ cdef class PGConnection: if state is not None and start == 0: self._build_apply_state_req(state, out) + if needs_commit_state or self.state_reset_needs_commit: + self.write_sync(out) # Build the parse_array first, closing statements if needed before # actually executing any command that may fail, in order to ensure @@ -716,13 +763,16 @@ cdef class PGConnection: finally: await self.after_command() - async def wait_for_state_resp(self, bytes state, bint state_sync): + async def wait_for_state_resp( + self, bytes state, bint state_sync, bint needs_commit_state + ): if state_sync: try: await self._parse_apply_state_resp(2 if state is None else 3) finally: await self.wait_for_sync() self.last_state = state + self.state_reset_needs_commit = needs_commit_state else: await self._parse_apply_state_resp(2 if state is None else 3) @@ -973,6 +1023,7 @@ cdef class PGConnection: tx_isolation, list param_data_types, bytes query_prefix, + bint needs_commit_state, ): cdef: WriteBuffer out @@ -1005,6 +1056,8 @@ cdef class PGConnection: or not query.is_transactional or query.run_and_rollback or tx_isolation is not None + or needs_commit_state + or self.state_reset_needs_commit ): # This query has START TRANSACTION or non-transactional command # like CREATE DATABASE in it. @@ -1194,7 +1247,8 @@ cdef class PGConnection: try: if state is not None: - await self.wait_for_state_resp(state, state_sync) + await self.wait_for_state_resp( + state, state_sync, needs_commit_state) if query.run_and_rollback or tx_isolation is not None: await self.wait_for_sync() @@ -1304,6 +1358,7 @@ cdef class PGConnection: bint use_pending_func_cache = 0, tx_isolation = None, query_prefix = None, + bint needs_commit_state = False, ): self.before_command() started_at = time.monotonic() @@ -1319,6 +1374,7 @@ cdef class PGConnection: tx_isolation, param_data_types, query_prefix or b'', + needs_commit_state, ) finally: metrics.backend_query_duration.observe( @@ -1493,6 +1549,8 @@ cdef class PGConnection: ) await self.wait_for_sync() self.last_state = state + self.state_reset_needs_commit = ( + dbv.needs_commit_after_state_sync()) finally: await self.after_command() @@ -1512,6 +1570,8 @@ cdef class PGConnection: ) await self.wait_for_sync() self.last_state = state + self.state_reset_needs_commit = ( + dbv.needs_commit_after_state_sync()) try: return await self._parse_sql_extended_query( actions, @@ -1522,6 +1582,8 @@ cdef class PGConnection: finally: if not dbv.in_tx(): self.last_state = dbv.serialize_state() + self.state_reset_needs_commit = ( + dbv.needs_commit_after_state_sync()) finally: await self.after_command() diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 9ec263f8c87..ba0b52f9feb 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -77,6 +77,7 @@ cdef class ExecutionGroup: object dbv, # can be DatabaseConnectionView or Database fe_conn: frontend.AbstractFrontendConnection = None, bytes state = None, + bint needs_commit_state = False, ): cdef int dbver @@ -94,9 +95,14 @@ cdef class ExecutionGroup: dbver, parse_array, None, # query_prefix + needs_commit_state, ) if state is not None: - await be_conn.wait_for_state_resp(state, state_sync=0) + await be_conn.wait_for_state_resp( + state, + state_sync=needs_commit_state, + needs_commit_state=needs_commit_state, + ) for i, unit in enumerate(self.group): ignore_data = unit.output_format == FMT_NONE rv = await be_conn.wait_for_command( @@ -238,11 +244,13 @@ async def execute( cdef: bytes state = None, orig_state = None WriteBuffer bound_args_buf + bint needs_commit_state = False query_unit = compiled.query_unit_group[0] if not dbv.in_tx(): orig_state = state = dbv.serialize_state() + needs_commit_state = dbv.needs_commit_after_state_sync() new_types = None server = dbv.server @@ -269,13 +277,19 @@ async def execute( close_frontend_conns=query_unit.drop_db_reset_connections, ) if query_unit.system_config: + # execute_system_config() always sync state in a separate tx, + # so we don't need to pass down the needs_commit_state here await execute_system_config(be_conn, dbv, query_unit, state) else: config_ops = query_unit.config_ops if query_unit.sql: if query_unit.user_schema: - await be_conn.parse_execute(query=query_unit, state=state) + await be_conn.parse_execute( + query=query_unit, + state=state, + needs_commit_state=needs_commit_state, + ) if query_unit.ddl_stmt_id is not None: ddl_ret = be_conn.load_last_ddl_return(query_unit) if ddl_ret and ddl_ret['new_types']: @@ -299,6 +313,7 @@ async def execute( param_data_types=data_types, use_prep_stmt=use_prep_stmt, state=state, + needs_commit_state=needs_commit_state, dbver=dbv.dbver, use_pending_func_cache=compiled.use_pending_func_cache, tx_isolation=tx_isolation, @@ -411,6 +426,8 @@ async def execute( # 1. An orphan ROLLBACK command without a paring start tx # 2. There was no SQL, so the state can't have been synced. be_conn.last_state = state + be_conn.state_reset_needs_commit = ( + dbv.needs_commit_after_state_sync()) if compiled.recompiled_cache: for req, qu_group in compiled.recompiled_cache: dbv.cache_compiled_query(req, qu_group) @@ -437,7 +454,7 @@ async def execute_script( object global_schema, roles WriteBuffer bind_data int dbver = dbv.dbver - bint parse + bint parse, needs_commit_state = False user_schema = extensions = ext_config_settings = cached_reflection = None feature_used_metrics = None @@ -451,6 +468,7 @@ async def execute_script( in_tx = dbv.in_tx() if not in_tx: orig_state = state = dbv.serialize_state() + needs_commit_state = dbv.needs_commit_after_state_sync() data = None @@ -503,10 +521,16 @@ async def execute_script( dbver, parse_array, query_prefix, + needs_commit_state, ) if idx == 0 and state is not None: - await conn.wait_for_state_resp(state, state_sync=0) + await conn.wait_for_state_resp( + state, + state_sync=needs_commit_state, + needs_commit_state=needs_commit_state, + ) + conn.state_reset_needs_commit = needs_commit_state # state is restored, clear orig_state so that we can # set conn.last_state correctly later orig_state = None @@ -622,6 +646,8 @@ async def execute_script( state = dbv.serialize_state() if state is not orig_state: conn.last_state = state + conn.state_reset_needs_commit = ( + dbv.needs_commit_after_state_sync()) elif updated_user_schema: dbv._in_tx_user_schema_pickle = user_schema diff --git a/edb/server/protocol/pg_ext.pxd b/edb/server/protocol/pg_ext.pxd index 357849c6307..56e9b8788ed 100644 --- a/edb/server/protocol/pg_ext.pxd +++ b/edb/server/protocol/pg_ext.pxd @@ -53,6 +53,7 @@ cdef class ConnectionView: cdef inline _reset_tx_state( self, bint chain_implicit, bint chain_explicit ) + cdef bint needs_commit_after_state_sync(self) cpdef inline close_portal_if_exists(self, str name) cpdef inline close_portal(self, str name) cdef inline find_portal(self, str name) diff --git a/edb/server/protocol/pg_ext.pyx b/edb/server/protocol/pg_ext.pyx index 0c1b71a78a2..cf49642a141 100644 --- a/edb/server/protocol/pg_ext.pyx +++ b/edb/server/protocol/pg_ext.pyx @@ -360,6 +360,16 @@ cdef class ConnectionView: self._session_state_db_cache = (self._settings, rv) return rv + cdef bint needs_commit_after_state_sync(self): + return any( + tx_conf in self._settings + for tx_conf in [ + "default_transaction_isolation", + "default_transaction_deferrable", + "default_transaction_read_only", + ] + ) + cdef class PgConnection(frontend.FrontendConnection): interface = "sql" diff --git a/tests/test_server_proto.py b/tests/test_server_proto.py index 38dc61580b0..0017339abef 100644 --- a/tests/test_server_proto.py +++ b/tests/test_server_proto.py @@ -2069,6 +2069,46 @@ async def test_server_proto_tx_22(self): self.assertEqual(await self.con.query_single('SELECT 42'), 42) + async def assert_tx_isolation_and_default( + self, expected: str, *, default: str | None = None, conn=None + ): + if conn is None: + conn = self.con + if default is None: + default = expected + self.assertEqual( + await conn.query_single(''' + select ( + sys::get_transaction_isolation(), + assert_single(cfg::Config) + .default_transaction_isolation, + ); + '''), + (expected, default), + ) + + async def assert_read_only_and_default( + self, reason: str, *, default: str = 'ReadOnly', conn=None + ): + if conn is None: + conn = self.con + self.assertEqual( + await conn.query_single( + 'select assert_single(cfg::Config)' + '.default_transaction_access_mode;', + ), + default, + ) + with self.assertRaisesRegex( + edgedb.TransactionError, + reason, + ): + await self.con.query(''' + INSERT Tmp { + tmp := 'aaa' + }; + ''') + async def test_server_proto_tx_23(self): # Test that default_transaction_isolation is respected @@ -2078,12 +2118,7 @@ async def test_server_proto_tx_23(self): ''') try: - self.assertEqual( - await self.con.query( - 'select sys::get_transaction_isolation();', - ), - ["RepeatableRead"], - ) + await self.assert_tx_isolation_and_default('RepeatableRead') finally: await self.con.query(''' CONFIGURE SESSION @@ -2099,15 +2134,10 @@ async def test_server_proto_tx_24(self): ''') try: - with self.assertRaisesRegex( - edgedb.TransactionError, - 'cannot execute.*RepeatableRead', - ): - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_read_only_and_default( + "cannot execute.*RepeatableRead", + default='ReadWrite', + ) finally: await self.con.query(''' CONFIGURE SESSION @@ -2131,15 +2161,10 @@ async def test_server_proto_tx_25(self): SET default_transaction_access_mode := 'ReadWrite'; ''') - with self.assertRaisesRegex( - edgedb.TransactionError, - 'cannot execute.*RepeatableRead', - ): - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_read_only_and_default( + "cannot execute.*RepeatableRead", + default='ReadWrite', + ) finally: await self.con.query(''' CONFIGURE SESSION @@ -2163,22 +2188,8 @@ async def test_server_proto_tx_26(self): ''') try: - self.assertEqual( - await self.con.query( - 'select sys::get_transaction_isolation();', - ), - ["Serializable"], - ) - - with self.assertRaisesRegex( - edgedb.TransactionError, - 'cannot execute.*ReadOnly', - ): - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_tx_isolation_and_default('Serializable') + await self.assert_read_only_and_default('cannot execute.*ReadOnly') finally: await self.con.query(''' CONFIGURE SESSION @@ -2201,12 +2212,7 @@ async def test_server_proto_tx_27(self): START TRANSACTION; ''') - self.assertEqual( - await self.con.query( - 'select sys::get_transaction_isolation();', - ), - ["RepeatableRead"], - ) + await self.assert_tx_isolation_and_default('RepeatableRead') finally: await self.con.query(f''' @@ -2233,15 +2239,9 @@ async def test_server_proto_tx_28(self): START TRANSACTION; ''') - with self.assertRaisesRegex( - edgedb.TransactionError, - 'read-only transaction'): - - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_read_only_and_default( + 'read-only transaction', default='ReadWrite' + ) finally: await self.con.query(f''' ROLLBACK; @@ -2267,22 +2267,8 @@ async def test_server_proto_tx_29(self): START TRANSACTION; ''') - self.assertEqual( - await self.con.query( - 'select sys::get_transaction_isolation();', - ), - ["Serializable"], - ) - - with self.assertRaisesRegex( - edgedb.TransactionError, - 'read-only transaction'): - - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_tx_isolation_and_default('Serializable') + await self.assert_read_only_and_default('read-only transaction') finally: await self.con.query(f''' ROLLBACK; @@ -2338,22 +2324,10 @@ async def test_server_proto_tx_31(self): START TRANSACTION ISOLATION REPEATABLE READ; ''') - self.assertEqual( - await self.con.query( - 'select sys::get_transaction_isolation();', - ), - ["RepeatableRead"], + await self.assert_tx_isolation_and_default( + 'RepeatableRead', default='Serializable' ) - - with self.assertRaisesRegex( - edgedb.TransactionError, - 'read-only transaction'): - - await self.con.query(''' - INSERT Tmp { - tmp := 'aaa' - }; - ''') + await self.assert_read_only_and_default('read-only transaction') finally: await self.con.query(f''' ROLLBACK; @@ -2367,6 +2341,94 @@ async def test_server_proto_tx_31(self): await self.con.query('SELECT 42'), [42]) + async def test_server_proto_tx_32(self): + # Test state sync across 2 frontend connections works fine + con2 = await self.connect(database=self.con.dbname) + try: + await con2.query(''' + CONFIGURE SESSION + SET default_transaction_isolation := 'RepeatableRead'; + ''') + await self.assert_tx_isolation_and_default( + 'RepeatableRead', conn=con2 + ) + + # Try a few times back and forth - this should be enough to hit + # the same backend connection. + for _ in range(5): + # Test state reset + await self.assert_tx_isolation_and_default('Serializable') + + # Test state sync + await self.assert_tx_isolation_and_default( + 'RepeatableRead', conn=con2 + ) + finally: + await con2.aclose() + + async def _test_with_sql_connection(self, test_func): + # Test state sync across EdgeQL / SQL interfaces + try: + import asyncpg + except ImportError: + self.skipTest("asyncpg is not installed") + + conn_args = self.get_connect_args(database=self.con.dbname) + scon = await asyncpg.connect( + host=conn_args['host'], + port=conn_args['port'], + user=conn_args['user'], + database=conn_args['database'], + password=conn_args['password'], + ssl='require' + ) + + try: + await test_func(scon) + finally: + await self.con.query(''' + CONFIGURE SESSION + RESET default_transaction_isolation; + ''') + await scon.close() + + async def test_server_proto_tx_33(self): + async def test(scon): + await scon.execute(''' + set default_transaction_isolation to 'repeatable read'; + ''') + for _ in range(5): + await self.assert_tx_isolation_and_default('Serializable') + self.assertEqual( + await scon.fetchval('show transaction_isolation'), + 'repeatable read', + ) + self.assertEqual( + await scon.fetchval('show default_transaction_isolation'), + 'repeatable read', + ) + + await self._test_with_sql_connection(test) + + async def test_server_proto_tx_34(self): + async def test(scon): + await self.con.query(''' + CONFIGURE SESSION + SET default_transaction_isolation := 'RepeatableRead'; + ''') + for _ in range(5): + await self.assert_tx_isolation_and_default('RepeatableRead') + self.assertEqual( + await scon.fetchval('show transaction_isolation'), + 'serializable', + ) + self.assertEqual( + await scon.fetchval('show default_transaction_isolation'), + 'serializable', + ) + + await self._test_with_sql_connection(test) + class TestServerProtoMigration(tb.QueryTestCase):