55from crate .client .exceptions import ProgrammingError
66from cr8 .run_crate import CrateNode
77
8- from crate .qa .tests import NodeProvider , insert_data , wait_for_active_shards , UpgradePath
8+ from crate .qa .tests import NodeProvider , insert_data , wait_for_active_shards , UpgradePath , assert_busy
99
1010ROLLING_UPGRADES_V4 = (
1111 # 4.0.0 -> 4.0.1 -> 4.0.2 don't support rolling upgrades due to a bug
@@ -106,6 +106,9 @@ def _test_rolling_upgrade(self, path: UpgradePath, nodes: int):
106106 with connect (remote_node .http_url , error_trace = True ) as remote_conn :
107107 new_shards = init_foreign_data_wrapper_data (conn , remote_conn , node .addresses .psql .port , remote_node .addresses .psql .port )
108108 expected_active_shards += new_shards
109+ if node .version >= (5 , 10 , 0 ):
110+ new_shards = init_logical_replication_data (self , conn , remote_conn , node .addresses .transport .port , remote_node .addresses .transport .port , expected_active_shards )
111+ expected_active_shards += new_shards
109112
110113 for idx , node in enumerate (cluster ):
111114 # Enforce an old version node be a handler to make sure that an upgraded node can serve 'select *' from an old version node.
@@ -140,6 +143,8 @@ def _test_rolling_upgrade(self, path: UpgradePath, nodes: int):
140143 assert remote_node is not None
141144 with connect (remote_node .http_url , error_trace = True ) as remote_conn :
142145 test_foreign_data_wrapper (self , conn , remote_conn )
146+ if node .version >= (5 , 10 , 0 ):
147+ test_logical_replication_queries (self , conn , remote_conn )
143148
144149 # Finally validate that all shards (primaries and replicas) of all partitions are started
145150 # and writes into the partitioned table while upgrading were successful
@@ -380,3 +385,57 @@ def test_foreign_data_wrapper(self, local_conn: Connection, remote_conn: Connect
380385 rc .execute ("refresh table doc.y" )
381386 c .execute ("select count(a) from doc.remote_y" )
382387 self .assertEqual (c .fetchall ()[0 ][0 ], count + 1 )
388+
389+
390+ def init_logical_replication_data (self , local_conn : Connection , remote_conn : Connection , local_transport_port :int , remote_transport_port : int , local_active_shards : int ) -> int :
391+ assert 4300 <= local_transport_port <= 4310 and 4300 <= remote_transport_port <= 4310
392+
393+ c = local_conn .cursor ()
394+ c .execute ("create table doc.x (a int) clustered into 1 shards with (number_of_replicas=0)" )
395+ c .execute ("create publication p for table doc.x" )
396+
397+ rc = remote_conn .cursor ()
398+ rc .execute ("create table doc.rx (a int) clustered into 1 shards with (number_of_replicas=0)" )
399+ rc .execute ("create publication rp for table doc.rx" )
400+
401+ rc .execute (f"create subscription rs connection 'crate://localhost:{ local_transport_port } ?user=crate&sslmode=sniff' publication p" )
402+ c .execute (f"create subscription s connection 'crate://localhost:{ remote_transport_port } ?user=crate&sslmode=sniff' publication rp" )
403+
404+ new_shards = 2 # 1 shard for doc.x and another 1 shard for doc.rx
405+ wait_for_active_shards (rc , new_shards )
406+ wait_for_active_shards (c , local_active_shards + new_shards )
407+ assert_busy (lambda : self .assertEqual (num_docs_x (rc ), 0 ))
408+ assert_busy (lambda : self .assertEqual (num_docs_rx (c ), 0 ))
409+
410+ return new_shards
411+
412+
413+ def test_logical_replication_queries (self , local_conn : Connection , remote_conn : Connection ):
414+ c = local_conn .cursor ()
415+ rc = remote_conn .cursor ()
416+
417+ # Cannot drop replicated tables
418+ with self .assertRaises (ProgrammingError ):
419+ rc .execute ("drop table doc.x" )
420+ c .execute ("drop table doc.rx" )
421+
422+ count = num_docs_x (rc )
423+ count2 = num_docs_rx (c )
424+
425+ c .execute ("insert into doc.x values (1)" )
426+ c .execute ("refresh table doc.x" )
427+ rc .execute ("insert into doc.rx values (1)" )
428+ rc .execute ("refresh table doc.rx" )
429+
430+ assert_busy (lambda : self .assertEqual (num_docs_x (rc ), count + 1 ))
431+ assert_busy (lambda : self .assertEqual (num_docs_rx (c ), count2 + 1 ))
432+
433+
434+ def num_docs_x (cursor ):
435+ cursor .execute ("select count(*) from doc.x" )
436+ return cursor .fetchall ()[0 ][0 ]
437+
438+
439+ def num_docs_rx (cursor ):
440+ cursor .execute ("select count(*) from doc.rx" )
441+ return cursor .fetchall ()[0 ][0 ]
0 commit comments