Skip to content

Commit

Permalink
Merge pull request #196 from semiotic-ai/gusinacio/state-machine-refa…
Browse files Browse the repository at this point in the history
…ctor

refactor!: use typestate for receivedreceipt
  • Loading branch information
gusinacio authored Dec 22, 2023
2 parents d7f5939 + 8616d13 commit d841a44
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 428 deletions.
2 changes: 1 addition & 1 deletion tap_core/src/adapters/mock/receipt_checks_adapter_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl ReceiptChecksAdapter for ReceiptChecksAdapterMock {
Ok(receipt_storage
.iter()
.all(|(stored_receipt_id, stored_receipt)| {
(stored_receipt.signed_receipt.message != receipt.message)
(stored_receipt.signed_receipt().message != receipt.message)
|| *stored_receipt_id == receipt_id
}))
}
Expand Down
18 changes: 10 additions & 8 deletions tap_core/src/adapters/mock/receipt_storage_adapter_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use async_trait::async_trait;
use tokio::sync::RwLock;

use crate::{
adapters::receipt_storage_adapter::{safe_truncate_receipts, ReceiptRead, ReceiptStore},
adapters::receipt_storage_adapter::{
safe_truncate_receipts, ReceiptRead, ReceiptStore, StoredReceipt,
},
tap_receipt::ReceivedReceipt,
};

Expand Down Expand Up @@ -44,15 +46,15 @@ impl ReceiptStorageAdapterMock {
Ok(receipt_storage
.iter()
.filter(|(_, rx_receipt)| {
rx_receipt.signed_receipt.message.timestamp_ns == timestamp_ns
rx_receipt.signed_receipt().message.timestamp_ns == timestamp_ns
})
.map(|(&id, rx_receipt)| (id, rx_receipt.clone()))
.collect())
}
pub async fn retrieve_receipts_upto_timestamp(
&self,
timestamp_ns: u64,
) -> Result<Vec<(u64, ReceivedReceipt)>, AdapterErrorMock> {
) -> Result<Vec<StoredReceipt>, AdapterErrorMock> {
self.retrieve_receipts_in_timestamp_range(..=timestamp_ns, None)
.await
}
Expand Down Expand Up @@ -117,7 +119,7 @@ impl ReceiptStore for ReceiptStorageAdapterMock {
) -> Result<(), Self::AdapterError> {
let mut receipt_storage = self.receipt_storage.write().await;
receipt_storage.retain(|_, rx_receipt| {
!timestamp_ns.contains(&rx_receipt.signed_receipt.message.timestamp_ns)
!timestamp_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns)
});
Ok(())
}
Expand All @@ -130,22 +132,22 @@ impl ReceiptRead for ReceiptStorageAdapterMock {
&self,
timestamp_range_ns: R,
limit: Option<u64>,
) -> Result<Vec<(u64, ReceivedReceipt)>, Self::AdapterError> {
) -> Result<Vec<StoredReceipt>, Self::AdapterError> {
let receipt_storage = self.receipt_storage.read().await;
let mut receipts_in_range: Vec<(u64, ReceivedReceipt)> = receipt_storage
.iter()
.filter(|(_, rx_receipt)| {
timestamp_range_ns.contains(&rx_receipt.signed_receipt.message.timestamp_ns)
timestamp_range_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns)
})
.map(|(&id, rx_receipt)| (id, rx_receipt.clone()))
.collect();

if limit.is_some_and(|limit| receipts_in_range.len() > limit as usize) {
safe_truncate_receipts(&mut receipts_in_range, limit.unwrap());

Ok(receipts_in_range)
Ok(receipts_in_range.into_iter().map(|r| r.into()).collect())
} else {
Ok(receipts_in_range)
Ok(receipts_in_range.into_iter().map(|r| r.into()).collect())
}
}
}
24 changes: 19 additions & 5 deletions tap_core/src/adapters/receipt_storage_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,22 @@ pub trait ReceiptRead {
&self,
timestamp_range_ns: R,
limit: Option<u64>,
) -> Result<Vec<(u64, ReceivedReceipt)>, Self::AdapterError>;
) -> Result<Vec<StoredReceipt>, Self::AdapterError>;
}

pub struct StoredReceipt {
pub receipt_id: u64,
pub receipt: ReceivedReceipt,
}

impl From<(u64, ReceivedReceipt)> for StoredReceipt {
fn from((receipt_id, receipt): (u64, ReceivedReceipt)) -> Self {
Self {
receipt_id,
receipt,
}
}
}
/// See [`ReceiptStorageAdapter::retrieve_receipts_in_timestamp_range()`] for details.
///
/// WARNING: Will sort the receipts by timestamp using
Expand All @@ -130,18 +143,19 @@ pub fn safe_truncate_receipts(receipts: &mut Vec<(u64, ReceivedReceipt)>, limit:
return;
}

receipts.sort_unstable_by_key(|(_, rx_receipt)| rx_receipt.signed_receipt.message.timestamp_ns);
receipts
.sort_unstable_by_key(|(_, rx_receipt)| rx_receipt.signed_receipt().message.timestamp_ns);

// This one will be the last timestamp in `receipts` after naive truncation
let last_timestamp = receipts[limit as usize - 1]
.1
.signed_receipt
.signed_receipt()
.message
.timestamp_ns;
// This one is the timestamp that comes just after the one above
let after_last_timestamp = receipts[limit as usize]
.1
.signed_receipt
.signed_receipt()
.message
.timestamp_ns;

Expand All @@ -152,7 +166,7 @@ pub fn safe_truncate_receipts(receipts: &mut Vec<(u64, ReceivedReceipt)>, limit:
// remove all the receipts with the same timestamp as the last one, because
// otherwise we would leave behind part of the receipts for that timestamp.
receipts.retain(|(_, rx_receipt)| {
rx_receipt.signed_receipt.message.timestamp_ns != last_timestamp
rx_receipt.signed_receipt().message.timestamp_ns != last_timestamp
});
}
}
8 changes: 4 additions & 4 deletions tap_core/src/adapters/test/receipt_checks_adapter_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,19 @@ mod receipt_checks_adapter_unit_test {
.insert(unique_receipt_id, new_receipt.1.clone());

assert!(receipt_checks_adapter
.is_unique(&new_receipt.1.signed_receipt, unique_receipt_id)
.is_unique(new_receipt.1.signed_receipt(), unique_receipt_id)
.await
.unwrap());
assert!(receipt_checks_adapter
.is_valid_allocation_id(new_receipt.1.signed_receipt.message.allocation_id)
.is_valid_allocation_id(new_receipt.1.signed_receipt().message.allocation_id)
.await
.unwrap());
// TODO: Add check when sender_id is available from received receipt (issue: #56)
// assert!(receipt_checks_adapter.is_valid_sender_id(sender_id));
assert!(receipt_checks_adapter
.is_valid_value(
new_receipt.1.signed_receipt.message.value,
new_receipt.1.query_id
new_receipt.1.signed_receipt().message.value,
new_receipt.1.query_id()
)
.await
.unwrap());
Expand Down
9 changes: 5 additions & 4 deletions tap_core/src/adapters/test/receipt_storage_adapter_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ mod receipt_storage_adapter_unit_test {
receipt_storage_adapter::ReceiptStore,
receipt_storage_adapter_mock::ReceiptStorageAdapterMock,
};
use crate::tap_receipt::ReceivedReceipt;
use crate::{
eip_712_signed_message::EIP712SignedMessage, tap_receipt::get_full_list_of_checks,
tap_receipt::Receipt, tap_receipt::ReceivedReceipt,
tap_receipt::Receipt,
};

#[fixture]
Expand Down Expand Up @@ -135,7 +136,7 @@ mod receipt_storage_adapter_unit_test {
.await
.unwrap(),
);
receipt_timestamps.push(received_receipt.signed_receipt.message.timestamp_ns)
receipt_timestamps.push(received_receipt.signed_receipt().message.timestamp_ns)
}

// Retreive receipts with timestamp
Expand Down Expand Up @@ -241,12 +242,12 @@ mod receipt_storage_adapter_unit_test {
for (elem_trun, expected_timestamp) in receipts_truncated.iter().zip(expected.iter()) {
// Check timestamps
assert_eq!(
elem_trun.1.signed_receipt.message.timestamp_ns,
elem_trun.1.signed_receipt().message.timestamp_ns,
*expected_timestamp
);

// Check that the IDs are fine
assert_eq!(elem_trun.0, elem_trun.1.query_id);
assert_eq!(elem_trun.0, elem_trun.1.query_id());
}
}
}
83 changes: 57 additions & 26 deletions tap_core/src/tap_manager/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use crate::{
receipt_storage_adapter::{ReceiptRead, ReceiptStore},
},
receipt_aggregate_voucher::ReceiptAggregateVoucher,
tap_receipt::{ReceiptAuditor, ReceiptCheck, ReceivedReceipt},
tap_receipt::{
CategorizedReceiptsWithState, Failed, ReceiptAuditor, ReceiptCheck, ReceiptWithId,
ReceiptWithState, ReceivedReceipt, Reserved,
},
Error,
};

Expand Down Expand Up @@ -122,7 +125,13 @@ where
timestamp_buffer_ns: u64,
min_timestamp_ns: u64,
limit: Option<u64>,
) -> Result<(Vec<SignedReceipt>, Vec<ReceivedReceipt>), Error> {
) -> Result<
(
Vec<ReceiptWithState<Reserved>>,
Vec<ReceiptWithState<Failed>>,
),
Error,
> {
let max_timestamp_ns = crate::get_current_timestamp_u64_ns()? - timestamp_buffer_ns;

if min_timestamp_ns > max_timestamp_ns {
Expand All @@ -139,30 +148,38 @@ where
source_error: anyhow::Error::new(err),
})?;

let mut accepted_signed_receipts = Vec::<SignedReceipt>::new();
let mut failed_signed_receipts = Vec::<ReceivedReceipt>::new();
let CategorizedReceiptsWithState {
checking_receipts,
mut awaiting_reserve_receipts,
mut failed_receipts,
mut reserved_receipts,
} = received_receipts.into();

let mut received_receipts: Vec<ReceivedReceipt> =
received_receipts.into_iter().map(|e| e.1).collect();
for received_receipt in checking_receipts {
let ReceiptWithId {
receipt,
receipt_id,
} = received_receipt;
let receipt = receipt
.finalize_receipt_checks(receipt_id, &self.receipt_auditor)
.await;

for check in self.required_checks.iter() {
ReceivedReceipt::perform_check_batch(
&mut received_receipts,
check,
&self.receipt_auditor,
)
.await?;
match receipt {
Ok(checked) => awaiting_reserve_receipts.push(checked),
Err(failed) => failed_receipts.push(failed),
}
}

for received_receipt in received_receipts {
if received_receipt.is_accepted() {
accepted_signed_receipts.push(received_receipt.signed_receipt);
} else {
failed_signed_receipts.push(received_receipt);
for checked in awaiting_reserve_receipts {
match checked
.check_and_reserve_escrow(&self.receipt_auditor)
.await
{
Ok(reserved) => reserved_receipts.push(reserved),
Err(failed) => failed_receipts.push(failed),
}
}

Ok((accepted_signed_receipts, failed_signed_receipts))
Ok((reserved_receipts, failed_receipts))
}
}

Expand Down Expand Up @@ -203,6 +220,10 @@ where
self.receipt_auditor
.update_min_timestamp_ns(expected_rav.timestamp_ns)
.await;
let valid_receipts = valid_receipts
.into_iter()
.map(|rx_receipt| rx_receipt.signed_receipt)
.collect::<Vec<_>>();

Ok(RAVRequest {
valid_receipts,
Expand All @@ -213,14 +234,22 @@ where
}

fn generate_expected_rav(
receipts: &[SignedReceipt],
receipts: &[ReceiptWithState<Reserved>],
previous_rav: Option<SignedRAV>,
) -> Result<ReceiptAggregateVoucher, Error> {
if receipts.is_empty() {
return Err(Error::NoValidReceiptsForRAVRequest);
}
let allocation_id = receipts[0].message.allocation_id;
ReceiptAggregateVoucher::aggregate_receipts(allocation_id, receipts, previous_rav)
let allocation_id = receipts[0].signed_receipt().message.allocation_id;
let receipts = receipts
.iter()
.map(|rx_receipt| rx_receipt.signed_receipt().clone())
.collect::<Vec<_>>();
ReceiptAggregateVoucher::aggregate_receipts(
allocation_id,
receipts.as_slice(),
previous_rav,
)
}
}

Expand Down Expand Up @@ -291,9 +320,11 @@ where
source_error: anyhow::Error::new(err),
})?;

received_receipt
.perform_checks(initial_checks, receipt_id, &self.receipt_auditor)
.await?;
if let ReceivedReceipt::Checking(received_receipt) = &mut received_receipt {
received_receipt
.perform_checks(initial_checks, receipt_id, &self.receipt_auditor)
.await;
}

self.receipt_storage_adapter
.update_receipt_by_id(receipt_id, received_receipt)
Expand Down
7 changes: 5 additions & 2 deletions tap_core/src/tap_manager/rav_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
use serde::{Deserialize, Serialize};

use super::{SignedRAV, SignedReceipt};
use crate::{receipt_aggregate_voucher::ReceiptAggregateVoucher, tap_receipt::ReceivedReceipt};
use crate::{
receipt_aggregate_voucher::ReceiptAggregateVoucher,
tap_receipt::{Failed, ReceiptWithState},
};

#[derive(Debug, Serialize, Deserialize, Clone)]

pub struct RAVRequest {
pub valid_receipts: Vec<SignedReceipt>,
pub previous_rav: Option<SignedRAV>,
pub invalid_receipts: Vec<ReceivedReceipt>,
pub invalid_receipts: Vec<ReceiptWithState<Failed>>,
pub expected_rav: ReceiptAggregateVoucher,
}
9 changes: 5 additions & 4 deletions tap_core/src/tap_receipt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ use std::collections::HashMap;
use alloy_primitives::Address;
pub use receipt::Receipt;
pub use receipt_auditor::ReceiptAuditor;
pub use received_receipt::{RAVStatus, ReceiptState, ReceivedReceipt};
pub use received_receipt::{
AwaitingReserve, CategorizedReceiptsWithState, Checking, Failed, ReceiptState, ReceiptWithId,
ReceiptWithState, ReceivedReceipt, Reserved, ResultReceipt,
};

use serde::{Deserialize, Serialize};
use strum_macros::{Display, EnumString};
use thiserror::Error;
Expand Down Expand Up @@ -44,7 +48,6 @@ pub enum ReceiptCheck {
CheckTimestamp,
CheckValue,
CheckSignature,
CheckAndReserveEscrow,
}

pub fn get_full_list_of_receipt_check_results() -> ReceiptCheckResults {
Expand All @@ -54,7 +57,6 @@ pub fn get_full_list_of_receipt_check_results() -> ReceiptCheckResults {
all_checks_list.insert(ReceiptCheck::CheckTimestamp, None);
all_checks_list.insert(ReceiptCheck::CheckValue, None);
all_checks_list.insert(ReceiptCheck::CheckSignature, None);
all_checks_list.insert(ReceiptCheck::CheckAndReserveEscrow, None);

all_checks_list
}
Expand All @@ -66,7 +68,6 @@ pub fn get_full_list_of_checks() -> Vec<ReceiptCheck> {
ReceiptCheck::CheckTimestamp,
ReceiptCheck::CheckValue,
ReceiptCheck::CheckSignature,
ReceiptCheck::CheckAndReserveEscrow,
]
}

Expand Down
Loading

0 comments on commit d841a44

Please sign in to comment.