From dc19f2028565a94d8fcdcf29ded702948039206a Mon Sep 17 00:00:00 2001
From: mdecimus <mauro@stalw.art>
Date: Sat, 11 Jan 2025 11:10:01 +0100
Subject: [PATCH] Fix tracking locked queue ids (closes #1066)

---
 crates/common/src/ipc.rs                 |  15 +-
 crates/main/Cargo.toml                   |   4 +-
 crates/smtp/src/outbound/delivery.rs     |  50 +++---
 crates/smtp/src/queue/manager.rs         | 186 +++++++++++++----------
 tests/resources/smtp/antispam/dmarc.test |  11 ++
 tests/src/smtp/inbound/mod.rs            |   6 +-
 tests/src/smtp/outbound/lmtp.rs          |   2 +-
 tests/src/smtp/outbound/smtp.rs          |   2 +-
 tests/src/smtp/queue/retry.rs            |   8 +-
 9 files changed, 158 insertions(+), 126 deletions(-)

diff --git a/crates/common/src/ipc.rs b/crates/common/src/ipc.rs
index d39f95ccc..c8a791279 100644
--- a/crates/common/src/ipc.rs
+++ b/crates/common/src/ipc.rs
@@ -128,16 +128,21 @@ pub struct EncryptionKeys {
 pub enum QueueEvent {
     Refresh(Option<u64>),
     WorkerDone(u64),
-    OnHold(OnHold<QueuedMessage>),
+    OnHold { queue_id: u64, status: OnHold },
     Paused(bool),
     Stop,
 }
 
 #[derive(Debug)]
-pub struct OnHold<T> {
-    pub next_due: Option<u64>,
-    pub limiters: Vec<ConcurrencyLimiter>,
-    pub message: T,
+pub enum OnHold {
+    InFlight,
+    Locked {
+        until: u64,
+    },
+    ConcurrencyLimited {
+        limiters: Vec<ConcurrencyLimiter>,
+        next_due: Option<u64>,
+    },
 }
 
 #[derive(Debug, Clone, Copy)]
diff --git a/crates/main/Cargo.toml b/crates/main/Cargo.toml
index 9b3965da6..264d902d6 100644
--- a/crates/main/Cargo.toml
+++ b/crates/main/Cargo.toml
@@ -34,8 +34,8 @@ tokio = { version = "1.23", features = ["full"] }
 jemallocator = "0.5.0"
 
 [features]
-default = ["sqlite", "postgres", "mysql", "rocks", "elastic", "s3", "redis", "azure", "enterprise"]
-#default = ["rocks", "enterprise"]
+#default = ["sqlite", "postgres", "mysql", "rocks", "elastic", "s3", "redis", "azure", "enterprise"]
+default = ["rocks", "enterprise"]
 sqlite = ["store/sqlite"]
 foundationdb = ["store/foundation", "common/foundation"]
 postgres = ["store/postgres"]
diff --git a/crates/smtp/src/outbound/delivery.rs b/crates/smtp/src/outbound/delivery.rs
index fb0c871c7..a38771573 100644
--- a/crates/smtp/src/outbound/delivery.rs
+++ b/crates/smtp/src/outbound/delivery.rs
@@ -18,7 +18,7 @@ use common::config::{
     server::ServerProtocol,
     smtp::{queue::RequireOptional, report::AggregateFrequency},
 };
-use common::ipc::{OnHold, PolicyType, QueueEvent, QueuedMessage, TlsEvent};
+use common::ipc::{OnHold, PolicyType, QueueEvent, TlsEvent};
 use common::Server;
 use mail_auth::{
     mta_sts::TlsRpt,
@@ -41,16 +41,7 @@ use super::{lookup::ToNextHop, mta_sts, session::SessionParams, NextHop, TlsStra
 use crate::queue::{throttle, DeliveryAttempt, Domain, Error, QueueEnvelope, Status};
 
 impl DeliveryAttempt {
-    pub fn try_deliver(mut self, server: Server) -> Option<OnHold<QueuedMessage>> {
-        // Global concurrency limiter
-        if let Err(limiter) = server.is_outbound_allowed(&mut self.in_flight) {
-            return Some(OnHold {
-                next_due: None,
-                limiters: vec![limiter],
-                message: self.event,
-            });
-        }
-
+    pub fn try_deliver(self, server: Server) {
         tokio::spawn(async move {
             // Lock queue event
             let queue_id = self.event.queue_id;
@@ -123,11 +114,12 @@ impl DeliveryAttempt {
                     QueueEvent::WorkerDone(queue_id)
                 }
             } else {
-                QueueEvent::OnHold(OnHold {
-                    next_due: Some(LOCK_EXPIRY + 1),
-                    limiters: vec![],
-                    message: self.event,
-                })
+                QueueEvent::OnHold {
+                    queue_id: self.event.queue_id,
+                    status: OnHold::Locked {
+                        until: now() + LOCK_EXPIRY + 1,
+                    },
+                }
             };
 
             // Notify queue manager
@@ -139,8 +131,6 @@ impl DeliveryAttempt {
                 );
             }
         });
-
-        None
     }
 
     async fn deliver_task(mut self, server: Server, mut message: Message) -> QueueEvent {
@@ -193,11 +183,13 @@ impl DeliveryAttempt {
                             SpanId = span_id,
                         );
 
-                        QueueEvent::OnHold(OnHold {
-                            next_due,
-                            limiters: vec![limiter],
-                            message: self.event,
-                        })
+                        QueueEvent::OnHold {
+                            queue_id: self.event.queue_id,
+                            status: OnHold::ConcurrencyLimited {
+                                next_due,
+                                limiters: vec![limiter],
+                            },
+                        }
                     }
                     throttle::Error::Rate { retry_at } => {
                         // Save changes to disk
@@ -1339,11 +1331,13 @@ impl DeliveryAttempt {
                 SpanId = span_id,
             );
 
-            QueueEvent::OnHold(OnHold {
-                next_due,
-                limiters: on_hold,
-                message: self.event,
-            })
+            QueueEvent::OnHold {
+                queue_id: self.event.queue_id,
+                status: OnHold::ConcurrencyLimited {
+                    next_due,
+                    limiters: on_hold,
+                },
+            }
         } else if let Some(due) = message.next_event() {
             trc::event!(
                 Queue(trc::QueueEvent::Rescheduled),
diff --git a/crates/smtp/src/queue/manager.rs b/crates/smtp/src/queue/manager.rs
index 8102a701c..5ca62de5c 100644
--- a/crates/smtp/src/queue/manager.rs
+++ b/crates/smtp/src/queue/manager.rs
@@ -9,24 +9,25 @@ use std::{
     time::{Duration, Instant},
 };
 
-use ahash::AHashSet;
+use ahash::{AHashMap, AHashSet};
 use common::{
     core::BuildServer,
-    ipc::{OnHold, QueueEvent, QueuedMessage},
+    ipc::{OnHold, QueueEvent},
     Inner,
 };
+use rand::seq::SliceRandom;
 use store::write::now;
 use tokio::sync::mpsc;
 
 use super::{
     spool::{SmtpSpool, QUEUE_REFRESH},
+    throttle::IsAllowed,
     DeliveryAttempt, Message, QueueId, Status,
 };
 
 pub struct Queue {
     pub core: Arc<Inner>,
-    pub on_hold: Vec<OnHold<QueuedMessage>>,
-    pub in_flight: AHashSet<QueueId>,
+    pub on_hold: AHashMap<QueueId, OnHold>,
     pub next_wake_up: Instant,
     pub rx: mpsc::Receiver<QueueEvent>,
 }
@@ -39,12 +40,13 @@ impl SpawnQueue for mpsc::Receiver<QueueEvent> {
     }
 }
 
+const CLEANUP_INTERVAL: Duration = Duration::from_secs(10 * 60);
+
 impl Queue {
     pub fn new(core: Arc<Inner>, rx: mpsc::Receiver<QueueEvent>) -> Self {
         Queue {
             core,
-            on_hold: Vec::with_capacity(128),
-            in_flight: AHashSet::with_capacity(128),
+            on_hold: AHashMap::with_capacity(128),
             next_wake_up: Instant::now(),
             rx,
         }
@@ -52,9 +54,10 @@ impl Queue {
 
     pub async fn start(&mut self) {
         let mut is_paused = false;
+        let mut next_cleanup = Instant::now() + CLEANUP_INTERVAL;
 
         loop {
-            let (on_hold, refresh_queue) = match tokio::time::timeout(
+            let refresh_queue = match tokio::time::timeout(
                 self.next_wake_up.duration_since(Instant::now()),
                 self.rx.recv(),
             )
@@ -62,20 +65,24 @@ impl Queue {
             {
                 Ok(Some(QueueEvent::Refresh(queue_id))) => {
                     if let Some(queue_id) = queue_id {
-                        self.in_flight.remove(&queue_id);
+                        self.on_hold.remove(&queue_id);
                     }
-
-                    (None, true)
+                    true
                 }
                 Ok(Some(QueueEvent::WorkerDone(queue_id))) => {
-                    self.in_flight.remove(&queue_id);
-
-                    (None, false)
+                    self.on_hold.remove(&queue_id);
+                    !self.on_hold.is_empty()
                 }
-                Ok(Some(QueueEvent::OnHold(on_hold))) => {
-                    self.in_flight.remove(&on_hold.message.queue_id);
+                Ok(Some(QueueEvent::OnHold { queue_id, status })) => {
+                    if let OnHold::Locked { until } = &status {
+                        let due_in = Instant::now() + Duration::from_secs(*until - now());
+                        if due_in < self.next_wake_up {
+                            self.next_wake_up = due_in;
+                        }
+                    }
 
-                    (on_hold.into(), false)
+                    self.on_hold.insert(queue_id, status);
+                    self.on_hold.len() > 1
                 }
                 Ok(Some(QueueEvent::Paused(paused))) => {
                     self.core
@@ -83,98 +90,115 @@ impl Queue {
                         .queue_status
                         .store(!paused, Ordering::Relaxed);
                     is_paused = paused;
-                    (None, false)
+                    false
                 }
-                Err(_) => (None, true),
+                Err(_) => true,
                 Ok(Some(QueueEvent::Stop)) | Ok(None) => {
                     break;
                 }
             };
 
             if !is_paused {
-                // Deliver any concurrency limited messages
-                let server = self.core.build_server();
-                while let Some(queue_event) = self.next_on_hold() {
-                    if let Some(message) =
-                        DeliveryAttempt::new(queue_event).try_deliver(server.clone())
-                    {
-                        self.on_hold(message);
-                    } else {
-                        self.in_flight.insert(queue_event.queue_id);
-                    }
-                }
-
                 // Deliver scheduled messages
                 if refresh_queue || self.next_wake_up <= Instant::now() {
                     let now = now();
                     let mut next_wake_up = QUEUE_REFRESH;
-                    for queue_event in server.next_event().await {
-                        match self.is_on_hold(queue_event.queue_id) {
-                            None => {
-                                if queue_event.due <= now {
-                                    if !self.in_flight.contains(&queue_event.queue_id) {
-                                        if let Some(message) = DeliveryAttempt::new(queue_event)
-                                            .try_deliver(server.clone())
+                    let server = self.core.build_server();
+
+                    // Process queue events
+                    let mut queue_events = server.next_event().await;
+
+                    if queue_events.len() > 5 {
+                        queue_events.shuffle(&mut rand::thread_rng());
+                    }
+
+                    for queue_event in &queue_events {
+                        if queue_event.due <= now {
+                            // Check if the message is still on hold
+                            if let Some(on_hold) = self.on_hold.get(&queue_event.queue_id) {
+                                match on_hold {
+                                    OnHold::Locked { until } => {
+                                        if *until > now {
+                                            let due_in = *until - now;
+                                            if due_in < next_wake_up {
+                                                next_wake_up = due_in;
+                                            }
+                                            continue;
+                                        }
+                                    }
+                                    OnHold::ConcurrencyLimited { limiters, next_due } => {
+                                        if !(limiters.iter().any(|l| {
+                                            l.concurrent.load(Ordering::Relaxed) < l.max_concurrent
+                                        }) || next_due.map_or(false, |due| due <= now))
                                         {
-                                            self.on_hold(message);
-                                        } else {
-                                            self.in_flight.insert(queue_event.queue_id);
+                                            continue;
                                         }
                                     }
-                                } else {
-                                    let due_in = queue_event.due - now;
-                                    if due_in < next_wake_up {
-                                        next_wake_up = due_in;
+                                    OnHold::InFlight => continue,
+                                }
+
+                                self.on_hold.remove(&queue_event.queue_id);
+                            }
+
+                            // Enforce global concurrency limits
+                            let mut in_flight = Vec::new();
+                            match server.is_outbound_allowed(&mut in_flight) {
+                                Ok(_) => {
+                                    self.on_hold.insert(queue_event.queue_id, OnHold::InFlight);
+                                    DeliveryAttempt {
+                                        in_flight,
+                                        event: *queue_event,
                                     }
+                                    .try_deliver(server.clone());
+                                }
+
+                                Err(limiter) => {
+                                    self.on_hold.insert(
+                                        queue_event.queue_id,
+                                        OnHold::ConcurrencyLimited {
+                                            limiters: vec![limiter],
+                                            next_due: None,
+                                        },
+                                    );
                                 }
                             }
-                            Some(on_hold)
-                                if on_hold.limiters.is_empty()
-                                    && on_hold.next_due.map_or(false, |due| due < next_wake_up) =>
-                            {
-                                next_wake_up = on_hold.next_due.unwrap();
+                        } else {
+                            let due_in = queue_event.due - now;
+                            if due_in < next_wake_up {
+                                next_wake_up = due_in;
                             }
-                            _ => (),
                         }
                     }
-                    self.next_wake_up = Instant::now() + Duration::from_secs(next_wake_up);
+
+                    // Remove expired locks
+                    let now = Instant::now();
+                    if next_cleanup <= now {
+                        next_cleanup = now + CLEANUP_INTERVAL;
+
+                        if !self.on_hold.is_empty() {
+                            let active_queue_ids = queue_events
+                                .into_iter()
+                                .map(|e| e.queue_id)
+                                .collect::<AHashSet<_>>();
+                            let now = store::write::now();
+                            self.on_hold.retain(|queue_id, status| match status {
+                                OnHold::InFlight => true,
+                                OnHold::Locked { until } => *until > now,
+                                OnHold::ConcurrencyLimited { .. } => {
+                                    active_queue_ids.contains(queue_id)
+                                }
+                            });
+                        }
+                    }
+
+                    self.next_wake_up = now + Duration::from_secs(next_wake_up);
                 }
             } else {
                 // Queue is paused
                 self.next_wake_up = Instant::now() + Duration::from_secs(86400);
             }
-
-            // Add message on hold
-            if let Some(on_hold) = on_hold {
-                self.on_hold(on_hold);
-            }
         }
     }
-
-    pub fn is_on_hold(&self, queue_id: QueueId) -> Option<&OnHold<QueuedMessage>> {
-        self.on_hold.iter().find(|o| o.message.queue_id == queue_id)
-    }
-
-    pub fn on_hold(&mut self, message: OnHold<QueuedMessage>) {
-        self.on_hold.push(OnHold {
-            next_due: message.next_due,
-            limiters: message.limiters,
-            message: message.message,
-        });
-    }
-
-    pub fn next_on_hold(&mut self) -> Option<QueuedMessage> {
-        let now = now();
-        self.on_hold
-            .iter()
-            .position(|o| {
-                o.limiters
-                    .iter()
-                    .any(|l| l.concurrent.load(Ordering::Relaxed) < l.max_concurrent)
-                    || o.next_due.map_or(false, |due| due <= now)
-            })
-            .map(|pos| self.on_hold.remove(pos).message)
-    }
 }
 
 impl Message {
diff --git a/tests/resources/smtp/antispam/dmarc.test b/tests/resources/smtp/antispam/dmarc.test
index 3f2dcade2..9e3a77e31 100644
--- a/tests/resources/smtp/antispam/dmarc.test
+++ b/tests/resources/smtp/antispam/dmarc.test
@@ -127,3 +127,14 @@ From: user@spf-dkim-allow.org
 Subject: test
 
 Test
+<!-- NEXT TEST -->
+spf.result pass
+dkim.result pass
+arc.result pass
+dmarc.result pass
+envelope_from hello@stalw.art
+expect TRUSTED_DOMAIN DMARC_POLICY_ALLOW DKIM_ALLOW SPF_ALLOW ARC_ALLOW
+
+From: <hello@stalw.art>
+
+Test
diff --git a/tests/src/smtp/inbound/mod.rs b/tests/src/smtp/inbound/mod.rs
index ded084666..a04e78e79 100644
--- a/tests/src/smtp/inbound/mod.rs
+++ b/tests/src/smtp/inbound/mod.rs
@@ -302,7 +302,7 @@ pub trait TestQueueEvent {
     fn assert_refresh(self);
     fn assert_done(self);
     fn assert_refresh_or_done(self);
-    fn unwrap_on_hold(self) -> OnHold<QueuedMessage>;
+    fn unwrap_on_hold(self) -> OnHold;
 }
 
 impl TestQueueEvent for QueueEvent {
@@ -327,9 +327,9 @@ impl TestQueueEvent for QueueEvent {
         }
     }
 
-    fn unwrap_on_hold(self) -> OnHold<QueuedMessage> {
+    fn unwrap_on_hold(self) -> OnHold {
         match self {
-            QueueEvent::OnHold(value) => value,
+            QueueEvent::OnHold { status, .. } => status,
             e => panic!("Unexpected event: {e:?}"),
         }
     }
diff --git a/tests/src/smtp/outbound/lmtp.rs b/tests/src/smtp/outbound/lmtp.rs
index 0feb4b741..eee466fc3 100644
--- a/tests/src/smtp/outbound/lmtp.rs
+++ b/tests/src/smtp/outbound/lmtp.rs
@@ -108,7 +108,7 @@ async fn lmtp_delivery() {
     loop {
         match local.queue_receiver.try_read_event().await {
             Some(QueueEvent::Refresh(_) | QueueEvent::WorkerDone(_)) => {}
-            Some(QueueEvent::OnHold(_)) | Some(QueueEvent::Paused(_)) => unreachable!(),
+            Some(QueueEvent::OnHold { .. }) | Some(QueueEvent::Paused(_)) => unreachable!(),
             None | Some(QueueEvent::Stop) => break,
         }
 
diff --git a/tests/src/smtp/outbound/smtp.rs b/tests/src/smtp/outbound/smtp.rs
index 7de05df03..1e65339a3 100644
--- a/tests/src/smtp/outbound/smtp.rs
+++ b/tests/src/smtp/outbound/smtp.rs
@@ -137,7 +137,7 @@ async fn smtp_delivery() {
     loop {
         match local.queue_receiver.try_read_event().await {
             Some(QueueEvent::Refresh(_) | QueueEvent::WorkerDone(_)) => {}
-            Some(QueueEvent::OnHold(_)) | Some(QueueEvent::Paused(_)) => unreachable!(),
+            Some(QueueEvent::OnHold { .. }) | Some(QueueEvent::Paused(_)) => unreachable!(),
             None | Some(QueueEvent::Stop) => break,
         }
 
diff --git a/tests/src/smtp/queue/retry.rs b/tests/src/smtp/queue/retry.rs
index 8b97d3fa3..f13c16c94 100644
--- a/tests/src/smtp/queue/retry.rs
+++ b/tests/src/smtp/queue/retry.rs
@@ -93,8 +93,8 @@ async fn queue_retry() {
             Some(QueueEvent::Refresh(Some(queue_id)) | QueueEvent::WorkerDone(queue_id)) => {
                 in_fight.remove(&queue_id);
             }
-            Some(QueueEvent::OnHold(event)) => {
-                panic!("unexpected on hold event: {:?}", event);
+            Some(QueueEvent::OnHold { queue_id, status }) => {
+                panic!("unexpected on hold event {queue_id}: {status:?}");
             }
             Some(QueueEvent::Refresh(None)) => (),
             None | Some(QueueEvent::Stop) | Some(QueueEvent::Paused(_)) => break,
@@ -121,9 +121,7 @@ async fn queue_retry() {
             } else {
                 retries.push(event.due.saturating_sub(now));
                 in_fight.insert(event.queue_id);
-                assert!(DeliveryAttempt::new(event)
-                    .try_deliver(core.clone())
-                    .is_none());
+                DeliveryAttempt::new(event).try_deliver(core.clone());
                 tokio::time::sleep(Duration::from_millis(100)).await;
             }
         }