diff --git a/src/transaction/snapshot.rs b/src/transaction/snapshot.rs index a8aa9464..0d1e4803 100644 --- a/src/transaction/snapshot.rs +++ b/src/transaction/snapshot.rs @@ -22,7 +22,7 @@ use crate::Value; /// /// See the [Transaction](struct@crate::Transaction) docs for more information on the methods. #[derive(new)] -pub struct Snapshot> { +pub struct Snapshot = PdRpcClient> { transaction: Transaction, phantom: PhantomData, } diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index 671d6140..10a7f6e7 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -2,6 +2,8 @@ use std::iter; use std::marker::PhantomData; +use std::sync::atomic; +use std::sync::atomic::AtomicU8; use std::sync::Arc; use std::time::Instant; @@ -10,7 +12,6 @@ use fail::fail_point; use futures::prelude::*; use log::debug; use log::warn; -use tokio::sync::RwLock; use tokio::time::Duration; use crate::backoff::Backoff; @@ -76,8 +77,8 @@ use crate::Value; /// txn.commit().await.unwrap(); /// # }); /// ``` -pub struct Transaction> { - status: Arc>, +pub struct Transaction = PdRpcClient> { + status: Arc, timestamp: Timestamp, buffer: Buffer, rpc: Arc, @@ -99,7 +100,7 @@ impl> Transaction { TransactionStatus::Active }; Transaction { - status: Arc::new(RwLock::new(status)), + status: Arc::new(AtomicU8::new(status as u8)), timestamp, buffer: Buffer::new(options.is_pessimistic()), rpc, @@ -632,15 +633,16 @@ impl> Transaction { /// ``` pub async fn commit(&mut self) -> Result> { debug!("commiting transaction"); - { - let mut status = self.status.write().await; - if !matches!( - *status, - TransactionStatus::StartedCommit | TransactionStatus::Active - ) { - return Err(Error::OperationAfterCommitError); - } - *status = TransactionStatus::StartedCommit; + if !self.transit_status( + |status| { + matches!( + status, + TransactionStatus::StartedCommit | TransactionStatus::Active + ) + }, + TransactionStatus::StartedCommit, + ) { + return Err(Error::OperationAfterCommitError); } let primary_key = self.buffer.get_primary_key(); @@ -665,8 +667,7 @@ impl> Transaction { .await; if res.is_ok() { - let mut status = self.status.write().await; - *status = TransactionStatus::Committed; + self.set_status(TransactionStatus::Committed); } res } @@ -689,21 +690,18 @@ impl> Transaction { /// ``` pub async fn rollback(&mut self) -> Result<()> { debug!("rolling back transaction"); - { - let status = self.status.read().await; - if !matches!( - *status, - TransactionStatus::StartedRollback - | TransactionStatus::Active - | TransactionStatus::StartedCommit - ) { - return Err(Error::OperationAfterCommitError); - } - } - - { - let mut status = self.status.write().await; - *status = TransactionStatus::StartedRollback; + if !self.transit_status( + |status| { + matches!( + status, + TransactionStatus::StartedRollback + | TransactionStatus::Active + | TransactionStatus::StartedCommit + ) + }, + TransactionStatus::StartedRollback, + ) { + return Err(Error::OperationAfterCommitError); } let primary_key = self.buffer.get_primary_key(); @@ -721,8 +719,7 @@ impl> Transaction { .await; if res.is_ok() { - let mut status = self.status.write().await; - *status = TransactionStatus::Rolledback; + self.set_status(TransactionStatus::Rolledback); } res } @@ -906,8 +903,7 @@ impl> Transaction { /// Checks if the transaction can perform arbitrary operations. async fn check_allow_operation(&self) -> Result<()> { - let status = self.status.read().await; - match *status { + match self.get_status() { TransactionStatus::ReadOnly | TransactionStatus::Active => Ok(()), TransactionStatus::Committed | TransactionStatus::Rolledback @@ -946,9 +942,9 @@ impl> Transaction { loop { tokio::time::sleep(heartbeat_interval).await; { - let status = status.read().await; + let status: TransactionStatus = status.load(atomic::Ordering::Acquire).into(); if matches!( - *status, + status, TransactionStatus::Rolledback | TransactionStatus::Committed | TransactionStatus::Dropped @@ -977,16 +973,42 @@ impl> Transaction { } }); } + + fn get_status(&self) -> TransactionStatus { + self.status.load(atomic::Ordering::Acquire).into() + } + + fn set_status(&self, status: TransactionStatus) { + self.status.store(status as u8, atomic::Ordering::Release); + } + + fn transit_status(&self, check_status: F, next: TransactionStatus) -> bool + where + F: Fn(TransactionStatus) -> bool, + { + let mut current = self.get_status(); + while check_status(current) { + match self.status.compare_exchange_weak( + current as u8, + next as u8, + atomic::Ordering::AcqRel, + atomic::Ordering::Acquire, + ) { + Ok(_) => return true, + Err(x) => current = x.into(), + } + } + false + } } -impl Drop for Transaction { +impl> Drop for Transaction { fn drop(&mut self) { debug!("dropping transaction"); if std::thread::panicking() { return; } - let mut status = futures::executor::block_on(self.status.write()); - if *status == TransactionStatus::Active { + if self.get_status() == TransactionStatus::Active { match self.options.check_level { CheckLevel::Panic => { panic!("Dropping an active transaction. Consider commit or rollback it.") @@ -998,7 +1020,7 @@ impl Drop for Transaction { CheckLevel::None => {} } } - *status = TransactionStatus::Dropped; + self.set_status(TransactionStatus::Dropped); } } @@ -1432,22 +1454,38 @@ impl Committer { } } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone, Copy)] +#[repr(u8)] enum TransactionStatus { /// The transaction is read-only [`Snapshot`](super::Snapshot), no need to commit or rollback or panic on drop. - ReadOnly, + ReadOnly = 0, /// The transaction have not been committed or rolled back. - Active, + Active = 1, /// The transaction has committed. - Committed, + Committed = 2, /// The transaction has tried to commit. Only `commit` is allowed. - StartedCommit, + StartedCommit = 3, /// The transaction has rolled back. - Rolledback, + Rolledback = 4, /// The transaction has tried to rollback. Only `rollback` is allowed. - StartedRollback, + StartedRollback = 5, /// The transaction has been dropped. - Dropped, + Dropped = 6, +} + +impl From for TransactionStatus { + fn from(num: u8) -> Self { + match num { + 0 => TransactionStatus::ReadOnly, + 1 => TransactionStatus::Active, + 2 => TransactionStatus::Committed, + 3 => TransactionStatus::StartedCommit, + 4 => TransactionStatus::Rolledback, + 5 => TransactionStatus::StartedRollback, + 6 => TransactionStatus::Dropped, + _ => panic!("Unknown transaction status {}", num), + } + } } #[cfg(test)]