diff --git a/examples/counter.rs b/examples/counter.rs index 7cec4ee..96e8c85 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use pg_task::{Step, StepResult}; +use pg_task::{NextStep, Step, StepResult}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::{env, time::Duration}; @@ -21,7 +21,7 @@ async fn main() -> anyhow::Result<()> { init_logging()?; // Let's schedule a few tasks - pg_task::enqueue(&db, &Tasks::Count(Start { up_to: 2 }.into())).await?; + pg_task::enqueue(&db, &Tasks::Count(Start { up_to: 1000 }.into())).await?; // And run a worker pg_task::Worker::::new(db).run().await; @@ -35,16 +35,13 @@ pub struct Start { } #[async_trait] impl Step for Start { - async fn step(self, _db: &PgPool) -> StepResult> { + async fn step(self, _db: &PgPool) -> StepResult { println!("1..{}: start", self.up_to); - Ok(Some( - Proceed { - up_to: self.up_to, - started_at: Utc::now(), - cur: 0, - } - .into(), - )) + NextStep::now(Proceed { + up_to: self.up_to, + started_at: Utc::now(), + cur: 0, + }) } } @@ -59,7 +56,7 @@ impl Step for Proceed { const RETRY_LIMIT: i32 = 5; const RETRY_DELAY: Duration = Duration::from_secs(1); - async fn step(self, _db: &PgPool) -> StepResult> { + async fn step(self, _db: &PgPool) -> StepResult { // return Err(anyhow::anyhow!("bailing").into()); let Self { up_to, @@ -69,16 +66,13 @@ impl Step for Proceed { cur += 1; // println!("1..{up_to}: {cur}"); if cur < up_to { - Ok(Some( - Proceed { - up_to, - started_at, - cur, - } - .into(), - )) + NextStep::now(Proceed { + up_to, + started_at, + cur, + }) } else { - Ok(Some(Finish { up_to, started_at }.into())) + NextStep::now(Finish { up_to, started_at }) } } } @@ -90,7 +84,7 @@ pub struct Finish { } #[async_trait] impl Step for Finish { - async fn step(self, _db: &PgPool) -> StepResult> { + async fn step(self, _db: &PgPool) -> StepResult { let took = Utc::now() - self.started_at; let secs = num_seconds(took); let per_sec = self.up_to as f64 / secs; @@ -100,7 +94,7 @@ impl Step for Finish { secs, per_sec.round() ); - Ok(None) + NextStep::none() } } diff --git a/examples/delay.rs b/examples/delay.rs new file mode 100644 index 0000000..d2be090 --- /dev/null +++ b/examples/delay.rs @@ -0,0 +1,53 @@ +use async_trait::async_trait; +use pg_task::{NextStep, Step, StepResult}; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; +use std::{env, time::Duration}; + +// It wraps the task step into an enum which proxies necessary methods +pg_task::task!(FooBar { Foo, Bar }); + +// Also we need a enum representing all the possible tasks +pg_task::scheduler!(Tasks { FooBar }); + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let db = connect().await?; + + // Let's schedule a few tasks + for delay in [3, 1, 2] { + pg_task::enqueue(&db, &Tasks::FooBar(Foo(delay).into())).await?; + } + + // And run a worker + pg_task::Worker::::new(db).run().await; + + Ok(()) +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Foo(u64); +#[async_trait] +impl Step for Foo { + async fn step(self, _db: &PgPool) -> StepResult { + println!("Sleeping for {} sec", self.0); + NextStep::delay(Bar(self.0), Duration::from_secs(self.0)) + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Bar(u64); +#[async_trait] +impl Step for Bar { + async fn step(self, _db: &PgPool) -> StepResult { + println!("Woke up after {} sec", self.0); + NextStep::none() + } +} + +async fn connect() -> anyhow::Result { + dotenv::dotenv().ok(); + let db = PgPool::connect(&env::var("DATABASE_URL")?).await?; + sqlx::migrate!().run(&db).await?; + Ok(db) +} diff --git a/src/error.rs b/src/error.rs index 9fe55ca..6152913 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,3 +1,4 @@ +use crate::NextStep; use std::{error::Error as StdError, result::Result as StdResult}; use tracing::error; @@ -19,4 +20,4 @@ pub type Result = StdResult; pub type StepError = Box; /// Result returning from task steps -pub type StepResult = StdResult; +pub type StepResult = StdResult, StepError>; diff --git a/src/lib.rs b/src/lib.rs index 5d3ad95..dd923e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,13 @@ mod error; mod macros; +mod next_step; mod traits; +mod util; mod worker; pub use error::{Error, Result, StepError, StepResult}; +pub use next_step::NextStep; pub use traits::{Scheduler, Step}; pub use worker::Worker; diff --git a/src/macros.rs b/src/macros.rs index 8b3fabe..72dabb1 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -17,10 +17,16 @@ macro_rules! task { #[async_trait::async_trait] impl $crate::Step<$enum> for $enum { - async fn step(self, db: &sqlx::PgPool) -> $crate::StepResult> { - Ok(match self { - $(Self::$variant(inner) => inner.step(db).await?.map(Into::into),)* - }) + async fn step(self, db: &sqlx::PgPool) -> $crate::StepResult<$enum> { + match self { + $(Self::$variant(inner) => inner.step(db).await.map(|next| + match next { + NextStep::None => NextStep::None, + NextStep::Now(x) => NextStep::Now(x.into()), + NextStep::Delayed(x, d) => NextStep::Delayed(x.into(), d), + } + ),)* + } } fn retry_limit(&self) -> i32 { diff --git a/src/next_step.rs b/src/next_step.rs new file mode 100644 index 0000000..36a9d7f --- /dev/null +++ b/src/next_step.rs @@ -0,0 +1,29 @@ +use crate::StepResult; +use std::time::Duration; + +/// Represents next step of the task +pub enum NextStep { + /// The task is done + None, + /// Run the next step immediately + Now(T), + /// Delay the next step + Delayed(T, Duration), +} + +impl NextStep { + /// The task is done + pub fn none() -> StepResult { + Ok(Self::None) + } + + /// Run the next step immediately + pub fn now(step: impl Into) -> StepResult { + Ok(Self::Now(step.into())) + } + + /// Delay the next step + pub fn delay(step: impl Into, delay: Duration) -> StepResult { + Ok(Self::Delayed(step.into(), delay)) + } +} diff --git a/src/traits.rs b/src/traits.rs index 4847847..7e0f42e 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,4 +1,4 @@ -use crate::{Error, Result, StepResult}; +use crate::{Error, StepResult}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use serde::{de::DeserializeOwned, Serialize}; @@ -19,7 +19,7 @@ where const RETRY_DELAY: Duration = Duration::from_secs(1); /// Processes the current step and returns the next if any - async fn step(self, db: &PgPool) -> StepResult>; + async fn step(self, db: &PgPool) -> StepResult; /// Proxies the `RETRY` const, doesn't mean to be changed in impls fn retry_limit(&self) -> i32 { @@ -36,19 +36,19 @@ where #[async_trait] pub trait Scheduler: fmt::Debug + DeserializeOwned + Serialize + Sized + Sync { /// Enqueues the task to be run immediately - async fn enqueue(&self, db: &PgPool) -> Result { + async fn enqueue(&self, db: &PgPool) -> crate::Result { self.schedule(db, Utc::now()).await } /// Schedules a task to be run after a specified delay - async fn delay(&self, db: &PgPool, delay: Duration) -> Result { + async fn delay(&self, db: &PgPool, delay: Duration) -> crate::Result { let delay = chrono::Duration::from_std(delay).unwrap_or_else(|_| chrono::Duration::max_value()); self.schedule(db, Utc::now() + delay).await } /// Schedules a task to run at a specified time in the future - async fn schedule(&self, db: &PgPool, at: DateTime) -> Result { + async fn schedule(&self, db: &PgPool, at: DateTime) -> crate::Result { let step = serde_json::to_string(self).map_err(Error::SerializeStep)?; sqlx::query!( "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id", diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..b8feeb9 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,11 @@ +/// Converts a chrono duration to std, it uses absolute value of the chrono duration +pub fn chrono_duration_to_std(chrono_duration: chrono::Duration) -> std::time::Duration { + let seconds = chrono_duration.num_seconds(); + let nanos = chrono_duration.num_nanoseconds().unwrap_or(0) % 1_000_000_000; + std::time::Duration::new(seconds.unsigned_abs(), nanos.unsigned_abs() as u32) +} + +/// Converts a std duration to chrono +pub fn std_duration_to_chrono(std_duration: std::time::Duration) -> chrono::Duration { + chrono::Duration::from_std(std_duration).unwrap_or_else(|_| chrono::Duration::max_value()) +} diff --git a/src/worker.rs b/src/worker.rs index adf4e25..aa6b080 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,4 +1,4 @@ -use crate::{Step, StepError}; +use crate::{util, NextStep, Step, StepError}; use chrono::{DateTime, Utc}; use code_path::code_path; use sqlx::{ @@ -65,7 +65,7 @@ impl Task { if delay <= chrono::Duration::zero() { Duration::ZERO } else { - chrono_duration_to_std(delay) + util::chrono_duration_to_std(delay) } } } @@ -129,8 +129,9 @@ impl> Worker { self.process_error(id, tried, retry_limit, retry_delay, e) .await? } - Ok(None) => self.finish_task(id).await?, - Ok(Some(step)) => self.update_task_step(id, step).await?, + Ok(NextStep::None) => self.finish_task(id).await?, + Ok(NextStep::Now(step)) => self.update_task_step(id, step, Duration::ZERO).await?, + Ok(NextStep::Delayed(step, delay)) => self.update_task_step(id, step, delay).await?, }; Ok(()) } @@ -171,7 +172,9 @@ impl> Worker { tx.commit() .await .map_err(sqlx_error!("commit on wait for a period"))?; - waiter.wait_for(chrono_duration_to_std(time_to_run)).await?; + waiter + .wait_for(util::chrono_duration_to_std(time_to_run)) + .await?; } else { tx.commit() .await @@ -182,7 +185,7 @@ impl> Worker { } /// Updates the tasks step - async fn update_task_step(&self, task_id: Uuid, step: S) -> Result<()> { + async fn update_task_step(&self, task_id: Uuid, step: S, delay: Duration) -> Result<()> { let step = match serde_json::to_string(&step) .map_err(|e| ErrorReport::SerializeStep(e, format!("{:?}", step))) { @@ -209,7 +212,7 @@ impl> Worker { ", task_id, step, - Utc::now(), + Utc::now() + util::std_duration_to_chrono(delay), ) .execute(&self.db) .await @@ -403,13 +406,6 @@ impl TaskWaiter { } } -/// Converts a chrono duration to std, it uses absolute value of the chrono duration -fn chrono_duration_to_std(chrono_duration: chrono::Duration) -> std::time::Duration { - let seconds = chrono_duration.num_seconds(); - let nanos = chrono_duration.num_nanoseconds().unwrap_or(0) % 1_000_000_000; - std::time::Duration::new(seconds.unsigned_abs(), nanos.unsigned_abs() as u32) -} - /// Returns the ordinal string of a given integer fn ordinal(n: i32) -> String { match n.abs() {