Skip to content

Commit

Permalink
Merge pull request #576 from aiven/kmichel-deflake-walreceiver
Browse files: browse the repository at this point in the history
walreceiver: stop replication when stopping thread
  • Loading branch information
alanfranz authored Dec 21, 2022
2 parents 7d35801 + 99b67f5 commit 2588f4b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 44 deletions.
102 changes: 65 additions & 37 deletions pghoard/walreceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import logging
import os
import select
import threading
import time
from io import BytesIO
from queue import Empty, Queue
from typing import Optional

import psycopg2
import psycopg2.errors
Expand Down Expand Up @@ -50,6 +52,8 @@ def __init__(
self.conn = None
self.c = None
self.buffer = BytesIO()
self.initial_lsn: Optional[LSN] = None
self.initial_lsn_available = threading.Event()
self.latest_wal = None
self.latest_wal_start = None
self.latest_activity = datetime.datetime.utcnow()
Expand Down Expand Up @@ -117,6 +121,8 @@ def start_replication(self):
lsn = LSN(self.last_flushed_lsn, self.pg_version_server)
else:
lsn = lsn_from_sysinfo(identify_system, self.pg_version_server)
self.initial_lsn = lsn
self.initial_lsn_available.set()
lsn = str(lsn.walfile_start_lsn)
self.log.info("Starting replication from %r, timeline: %r with slot: %r", lsn, timeline, self.replication_slot)
if self.replication_slot:
Expand All @@ -127,6 +133,14 @@ def start_replication(self):
self.c.start_replication(start_lsn=lsn, timeline=timeline)
return timeline

def stop_replication(self) -> None:
    """Close the replication cursor and its connection, if they are open.

    Both ``self.c`` and ``self.conn`` are reset to ``None`` before anything
    is closed, so the object never keeps a reference to a half-closed
    resource. Calling this again afterwards is a no-op.

    Raises:
        Whatever ``close()`` raises on the cursor or the connection. The
        connection is still closed when closing the cursor fails.
    """
    # Detach first: even if a close() below raises, the attributes no longer
    # point at stale objects.
    cursor, self.c = self.c, None
    connection, self.conn = self.conn, None
    try:
        if cursor is not None:
            cursor.close()
    finally:
        # Runs even when cursor.close() raised, so the connection is not leaked.
        if connection is not None:
            connection.close()

def switch_wal(self):
self.log.debug("Switching WAL from %r amount of data: %r", self.latest_wal, self.buffer.tell())

Expand Down Expand Up @@ -159,43 +173,57 @@ def switch_wal(self):

def run_safe(self):
# Streaming loop: initialises the cursor, starts replication and then reads
# replication messages into self.buffer for as long as self.running is true.
# NOTE(review): this span is rendered from a commit diff without +/- markers.
# The statements from here down to the first
# `self.process_completed_segments(block=True)` appear to be the removed body,
# and the `try:` ... `finally:` section after it the added body -- confirm
# against the repository before editing either half.
self._init_cursor()
if self.replication_slot:
self.create_replication_slot()
timeline = self.start_replication()
while self.running:
wal_name = None
try:
msg = self.c.read_message()
except psycopg2.DatabaseError as ex:
self.log.exception("Unexpected exception in reading walreceiver msg")
self.metrics.unexpected_exception(ex, where="walreceiver_run")
time.sleep(1)
continue
self.log.debug("replication_msg: %r, buffer: %r/%r", msg, self.buffer.tell(), WAL_SEG_SIZE)
if msg:
self.latest_activity = datetime.datetime.utcnow()
lsn = LSN(msg.data_start, timeline_id=timeline, server_version=self.pg_version_server)
wal_name = lsn.walfile_name

if not self.latest_wal:
self.latest_wal_start = lsn.lsn
self.latest_wal = wal_name
self.buffer.write(msg.payload)

# TODO: Calculate end pos and transmit that?
msg.cursor.send_feedback(write_lsn=lsn.lsn)

if wal_name and self.latest_wal != wal_name or self.buffer.tell() >= WAL_SEG_SIZE:
self.switch_wal()
self.process_completed_segments()

if not msg:
timeout = KEEPALIVE_INTERVAL - (datetime.datetime.now() - self.c.io_timestamp).total_seconds()
with suppress(InterruptedError):
if not any(select.select([self.c], [], [], max(0, timeout))):
self.c.send_feedback() # timing out, send keepalive
# When we stop, process sent wals to update last_flush lsn.
self.process_completed_segments(block=True)
try:
if self.replication_slot:
self.create_replication_slot()
timeline = self.start_replication()
while self.running:
wal_name = None
try:
msg = self.c.read_message()
except psycopg2.DatabaseError as ex:
# Log and report the error, pause for a second, then retry the read.
self.log.exception("Unexpected exception in reading walreceiver msg")
self.metrics.unexpected_exception(ex, where="walreceiver_run")
time.sleep(1)
continue
self.log.debug("replication_msg: %r, buffer: %r/%r", msg, self.buffer.tell(), WAL_SEG_SIZE)
if msg:
self.latest_activity = datetime.datetime.utcnow()
lsn = LSN(msg.data_start, timeline_id=timeline, server_version=self.pg_version_server)
wal_name = lsn.walfile_name

if self.buffer.tell() > 0 and self.buffer.tell() + len(msg.payload) > WAL_SEG_SIZE:
# If adding the payload would make the wal segment too large, switch the WAL
# now instead of adding the payload and having it written partly in the current
# wal segment and partly in the next one.
self.switch_wal()
self.process_completed_segments()

# First payload of a segment: remember its start LSN and WAL file name.
if not self.latest_wal:
self.latest_wal_start = lsn.lsn
self.latest_wal = wal_name
self.buffer.write(msg.payload)

# TODO: Calculate end pos and transmit that?
msg.cursor.send_feedback(write_lsn=lsn.lsn)

# `and` binds tighter than `or`: switch when this message belongs to a different
# WAL file than the buffered one, or whenever the buffer has reached WAL_SEG_SIZE.
if wal_name and self.latest_wal != wal_name or self.buffer.tell() >= WAL_SEG_SIZE:
self.switch_wal()
self.process_completed_segments()

if not msg:
# Nothing to read: wait for the cursor to become readable, for at most what is
# left of KEEPALIVE_INTERVAL since the cursor's io_timestamp.
timeout = KEEPALIVE_INTERVAL - (datetime.datetime.now() - self.c.io_timestamp).total_seconds()
with suppress(InterruptedError):
if not any(select.select([self.c], [], [], max(0.0, timeout))):
self.c.send_feedback() # timing out, send keepalive
# Don't leave unfinished segments waiting for more than the KEEPALIVE_INTERVAL
if self.buffer.tell() > 0:
self.switch_wal()
self.process_completed_segments()
# When we stop, process sent wals to update last_flush lsn.
self.process_completed_segments(block=True)
finally:
# Runs on normal exit and when an exception escapes the block above.
self.stop_replication()

def process_completed_segments(self, *, block=False):
for wal_start, queue in self.callbacks.items():
Expand Down
14 changes: 7 additions & 7 deletions test/test_walreceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def test_walreceiver(self, db, pghoard_walreceiver, replication_slot):
else:
node["slot"] = replication_slot

# The transfer agent state will be used to check what
# was uploaded
# Before starting the walreceiver, get the current wal name.
wal_name = get_current_lsn(node).walfile_name
# Start streaming, force a wal rotation, and check the wal has been
# archived
# The transfer agent state will be used to check what was uploaded
# Start streaming
pghoard.start_walreceiver(pghoard.test_site, node, None)
# Get the initial wal name of the server
pghoard.walreceivers[pghoard.test_site].initial_lsn_available.wait()
wal_name = pghoard.walreceivers[pghoard.test_site].initial_lsn.walfile_name
# Force a wal rotation
switch_wal(conn)
# Check that we uploaded one file, and it is the right one.
wait_for_xlog(pghoard, 1)
Expand All @@ -70,7 +70,7 @@ def test_walreceiver(self, db, pghoard_walreceiver, replication_slot):
previous_wal_name = lsn.previous_walfile_start_lsn.walfile_name
pghoard.start_walreceiver(pghoard.test_site, node, last_flushed_lsn)
wait_for_xlog(pghoard, 4)
last_flushed_lsn = stop_walreceiver(pghoard)
stop_walreceiver(pghoard)
state = get_transfer_agent_upload_xlog_state(pghoard)
assert state.get("xlogs_since_basebackup") == 4
assert state.get("latest_filename") == previous_wal_name
Expand Down

0 comments on commit 2588f4b

Please sign in to comment.