Skip to content

Commit

Permalink
pg: reset seq num in replication after commit msg
Browse files Browse the repository at this point in the history
Untracked transactions are intended to have a sequence number of -1,
indicating there was no sequence update. That sequence update is what
transactions from a pg.DB instance use to correlate WAL data updates
with a certain transaction.

This commit fixes the replication monitor to work correctly with
concurrent untracked transactions by:
- resetting the seq var to -1 in the captureRepl loop after each commit
- NOT sending the commit ID on the channel when a commit is reached with
  seq still -1 (i.e. no sequence update was seen, so the transaction is
  untracked)

This also beefs up the logging and error handling in unexpected cases,
which correspond to improper use of replMon methods or concurrent
(ab)use of the sentry table to break correct transaction sequencing.

Also, update the low-level replConn test: It only sends commit IDs for
sequenced transactions now because it needs to work with concurrent
transactions not created by a pg.DB instance.
  • Loading branch information
jchappelow authored Mar 11, 2024
1 parent e9411ce commit da52bba
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 16 deletions.
41 changes: 33 additions & 8 deletions internal/sql/pg/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func startRepl(ctx context.Context, conn *pgconn.PgConn, publicationName, slotNa
// Launch the receiver goroutine, which will send commit digests and an
// error on return.
done := make(chan error, 1)

// WARNING: there must be a commitHash receiver for every send. This is
// coordinated by only sending commit IDs on this channel for transactions
// containing a sequence number update on the internal sentry table. This
// means: (1) there must only be one pg.DB instance per postgres database,
// and (2) other unsequenced writers such as a pg.Pool must not make updates
// to the sentry table that would cause a send with no receiver.
commitHash := make(chan []byte, 1)

go func() {
Expand Down Expand Up @@ -194,7 +201,7 @@ func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64,
return fmt.Errorf("decodeWALData failed: %w", err)
}
if anySeq != -1 {
seq = anySeq
seq = anySeq // the magic sentry table UPDATE that precedes commit
}

var lsnDelta uint64
Expand All @@ -208,21 +215,34 @@ func captureRepl(ctx context.Context, conn *pgconn.PgConn, startLSN uint64,

if commit {
cHash := hasher.Sum(nil)
hasher.Reset() // hasher = sha256.New()

// Only send the commit ID on the commitHash channel if this was
// a tracked commit, which includes a sequence number update on
// the internal sentry table that indicates it was created by
// the pg.DB type.
if seq == -1 {
logger.Debugf("Commit hash %x (unsequenced / untracked) LSN %v (%d) delta %d",
cHash, xld.WALStart, xld.WALStart, lsnDelta)
stats.reset()
break // switch => continue loop
}

cid := binary.BigEndian.AppendUint64(nil, uint64(seq))
cid = append(cid, cHash...)
select {
case commitHash <- cid:
default: // don't block if the receiver has choked
return fmt.Errorf("commit hash channel full")
}
hasher.Reset() // hasher = sha256.New()

logger.Infof("Commit hash %x, seq %d, LSN %v (%d) delta %d",
cHash, seq, xld.WALStart, xld.WALStart, lsnDelta)

logger.Debug("wal commit stats", log.Uint("inserts", stats.inserts), log.Uint("updates", stats.updates),
log.Uint("deletes", stats.deletes), log.Uint("truncates", stats.truncs))
stats.reset()

seq = -1 // next commit may be untracked, forget this one
}

default:
Expand All @@ -242,6 +262,9 @@ func (ws *walStats) reset() {
*ws = walStats{}
}

// decodeWALData decodes a wal data message given known relations, returning
// true if it was a commit message, or a non-negative seq value if it was a
// special update message on the internal sentry table
func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglogrepl.RelationMessageV2,
inStream *bool, stats *walStats, okSchema func(schema string) bool) (bool, int64, error) {
logicalMsg, err := parseV3(walData, *inStream)
Expand All @@ -265,9 +288,9 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog
// from rolled back transactions.

case *pglogrepl.CommitMessage:
logger.Debugf(" [msg] Commit: Commit LSN %v (%d), End LSN %v (%d)",
logger.Debugf(" [msg] Commit: Commit LSN %v (%d), End LSN %v (%d), seq = %d",
logicalMsg.CommitLSN, uint64(logicalMsg.CommitLSN),
logicalMsg.TransactionEndLSN, uint64(logicalMsg.TransactionEndLSN))
logicalMsg.TransactionEndLSN, uint64(logicalMsg.TransactionEndLSN), seq)

done = true

Expand Down Expand Up @@ -295,7 +318,7 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog
case *pglogrepl.UpdateMessageV2:
rel, ok := relations[logicalMsg.RelationID]
if !ok {
return false, 0, fmt.Errorf("insert: unknown relation ID %d", logicalMsg.RelationID)
return false, 0, fmt.Errorf("update: unknown relation ID %d", logicalMsg.RelationID)
}

// capture the seq value, before target schema filter
Expand All @@ -304,9 +327,11 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog
if len(cols) != 1 {
logger.Warnf("not one column in sentry table update (%d)", len(cols))
} else {
seq, err = cols[0].Int64()
newSeq, err := cols[0].Int64()
if err != nil {
logger.Warnf("invalid sequence number in sentry table update: %v", err)
} else {
seq = newSeq
}
}
}
Expand All @@ -333,7 +358,7 @@ func decodeWALData(hasher hash.Hash, walData []byte, relations map[uint32]*pglog
case *pglogrepl.DeleteMessageV2:
rel, ok := relations[logicalMsg.RelationID]
if !ok {
return false, 0, fmt.Errorf("insert: unknown relation ID %d", logicalMsg.RelationID)
return false, 0, fmt.Errorf("delete: unknown relation ID %d", logicalMsg.RelationID)
}

relName := rel.Namespace + "." + rel.RelationName
Expand Down
27 changes: 20 additions & 7 deletions internal/sql/pg/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ func Test_repl(t *testing.T) {

ctx, cancel := context.WithDeadline(ctx, deadline.Add(-time.Second*5))
defer cancel()
connQ, err := pgx.Connect(ctx, connString(host, port, user, pass, dbName, false))
if err != nil {
t.Fatal(err)
}
_, err = connQ.Exec(ctx, sqlUpdateSentrySeq, 0)
if err != nil {
t.Fatal(err)
}

schemaFilter := func(string) bool { return true } // capture changes from all namespaces

Expand All @@ -56,11 +64,6 @@ func Test_repl(t *testing.T) {

t.Log("replication slot started and listening")

connQ, err := pgx.Connect(ctx, connString(host, port, user, pass, dbName, false))
if err != nil {
t.Fatal(err)
}

_, err = connQ.Exec(ctx, `DROP TABLE IF EXISTS blah`)
if err != nil {
t.Fatal(err)
Expand All @@ -71,7 +74,7 @@ func Test_repl(t *testing.T) {
t.Fatal(err)
}

wantCommitHash, _ := hex.DecodeString("9710a1c3b624c5a929425963c7441b0d8cf7d2bcf98aaaf8bc61519543aed1bc")
wantCommitHash, _ := hex.DecodeString("cb390afbf808256307ee0927999805ee3d5af193772e2c9b71823fbc1fe8867f")

var wg sync.WaitGroup
wg.Add(1)
Expand All @@ -91,7 +94,11 @@ func Test_repl(t *testing.T) {
}
cancel()
case err := <-errChan:
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
if errors.Is(err, context.Canceled) {
return
}
if errors.Is(err, context.DeadlineExceeded) {
t.Error("timeout")
return
}
if err != nil {
Expand All @@ -112,6 +119,12 @@ func Test_repl(t *testing.T) {
tx.Exec(ctx, `update blah SET stuff = 6, id = '{13}', val=41 where id = '{10}';`)
tx.Exec(ctx, `update blah SET stuff = 33;`)
tx.Exec(ctx, `delete FROM blah where id = '{11}';`)
// sends on commitChan are only expected from sequenced transactions.
// Bump seq in the sentry table!
_, err = tx.Exec(ctx, sqlUpdateSentrySeq, 1)
if err != nil {
t.Fatal(err)
}

err = tx.Commit(ctx) // this triggers the send
if err != nil {
Expand Down
10 changes: 9 additions & 1 deletion internal/sql/pg/replmon.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ type replMon struct {
done chan struct{}

mtx sync.Mutex
results map[int64][]byte
results map[int64][]byte // results should generally be unused as pg.DB will request a promise before commit
promises map[int64]chan []byte
}

Expand Down Expand Up @@ -97,6 +97,9 @@ func newReplMon(ctx context.Context, host, port, user, pass, dbName string, sche
p <- cHash
delete(rm.promises, seq)
} else {
// This is unexpected since pg.DB will call recvID first. If we are
// in this `else`, it is to be discarded, from another connection.
logger.Warnf("Received commit ID for seq %d BEFORE recvID", seq)
rm.results[seq] = cHash
}
rm.mtx.Unlock()
Expand Down Expand Up @@ -124,11 +127,16 @@ func (rm *replMon) recvID(seq int64) chan []byte {
rm.mtx.Lock()
defer rm.mtx.Unlock()
if cHash, ok := rm.results[seq]; ok {
// The intended use is to do recvID BEFORE the commit ID is received.
logger.Warnf("recvID with EXISTING result for sequence %d", seq)
delete(rm.results, seq)
c <- cHash
return c
}

if _, have := rm.promises[seq]; have {
logger.Errorf("Commit ID promise for sequence %d ALREADY EXISTS", seq)
}
rm.promises[seq] = c // maybe panic if one already exists, indicating program logic error

return c
Expand Down

0 comments on commit da52bba

Please sign in to comment.