Skip to content

Commit

Permalink
tests: add mock checks
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Inacio <[email protected]>
  • Loading branch information
gusinacio committed Mar 4, 2024
1 parent 6f5c821 commit 61e5137
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 123 deletions.
24 changes: 20 additions & 4 deletions tap_core/src/adapters/test/receipt_checks_adapter_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,34 @@ mod receipt_checks_adapter_unit_test {
use tokio::sync::RwLock;

use crate::{
checks::{tests::get_full_list_of_checks, ReceiptCheck},
eip_712_signed_message::EIP712SignedMessage,
tap_eip712_domain,
tap_receipt::{get_full_list_of_checks, Receipt, ReceivedReceipt},
tap_receipt::{Receipt, ReceivedReceipt},
};

#[fixture]
fn domain_separator() -> Eip712Domain {
tap_eip712_domain(1, Address::from([0x11u8; 20]))
}

#[fixture]
fn checks(domain_separator: Eip712Domain) -> Vec<ReceiptCheck> {
get_full_list_of_checks(
domain_separator,
HashSet::new(),
Arc::new(RwLock::new(HashSet::new())),
Arc::new(RwLock::new(HashMap::new())),
Arc::new(RwLock::new(HashMap::new())),
)
}

#[rstest]
#[tokio::test]
async fn receipt_checks_adapter_test(domain_separator: Eip712Domain) {
async fn receipt_checks_adapter_test(
domain_separator: Eip712Domain,
checks: Vec<ReceiptCheck>,
) {
let sender_ids = [
Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(),
Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(),
Expand All @@ -54,6 +69,7 @@ mod receipt_checks_adapter_unit_test {
.then(|id| {
let wallet = wallet.clone();
let domain_separator = &domain_separator;
let checks = checks.clone();
async move {
(
id,
Expand All @@ -65,7 +81,7 @@ mod receipt_checks_adapter_unit_test {
)
.unwrap(),
id,
&get_full_list_of_checks(),
&checks,
),
)
}
Expand Down Expand Up @@ -95,7 +111,7 @@ mod receipt_checks_adapter_unit_test {
)
.unwrap(),
10u64,
&get_full_list_of_checks(),
&checks,
),
);

Expand Down
31 changes: 21 additions & 10 deletions tap_core/src/adapters/test/receipt_storage_adapter_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
mod receipt_storage_adapter_unit_test {
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;

Expand All @@ -20,21 +20,31 @@ mod receipt_storage_adapter_unit_test {
receipt_storage_adapter::ReceiptStore,
receipt_storage_adapter_mock::ReceiptStorageAdapterMock,
};
use crate::checks::tests::get_full_list_of_checks;
use crate::checks::ReceiptCheck;
use crate::tap_eip712_domain;
use crate::tap_receipt::ReceivedReceipt;
use crate::{
eip_712_signed_message::EIP712SignedMessage, tap_receipt::get_full_list_of_checks,
tap_receipt::Receipt,
};
use crate::{eip_712_signed_message::EIP712SignedMessage, tap_receipt::Receipt};

#[fixture]
fn domain_separator() -> Eip712Domain {
tap_eip712_domain(1, Address::from([0x11u8; 20]))
}

#[fixture]
fn checks(domain_separator: Eip712Domain) -> Vec<ReceiptCheck> {
get_full_list_of_checks(
domain_separator,
HashSet::new(),
Arc::new(RwLock::new(HashSet::new())),
Arc::new(RwLock::new(HashMap::new())),
Arc::new(RwLock::new(HashMap::new())),
)
}

#[rstest]
#[tokio::test]
async fn receipt_adapter_test(domain_separator: Eip712Domain) {
async fn receipt_adapter_test(domain_separator: Eip712Domain, checks: Vec<ReceiptCheck>) {
let receipt_storage = Arc::new(RwLock::new(HashMap::new()));
let mut receipt_adapter = ReceiptStorageAdapterMock::new(receipt_storage);

Expand All @@ -57,7 +67,7 @@ mod receipt_storage_adapter_unit_test {
)
.unwrap(),
query_id,
&get_full_list_of_checks(),
&checks,
);

let receipt_store_result = receipt_adapter.store_receipt(received_receipt).await;
Expand Down Expand Up @@ -95,7 +105,7 @@ mod receipt_storage_adapter_unit_test {

#[rstest]
#[tokio::test]
async fn multi_receipt_adapter_test(domain_separator: Eip712Domain) {
async fn multi_receipt_adapter_test(domain_separator: Eip712Domain, checks: Vec<ReceiptCheck>) {
let receipt_storage = Arc::new(RwLock::new(HashMap::new()));
let mut receipt_adapter = ReceiptStorageAdapterMock::new(receipt_storage);

Expand All @@ -118,7 +128,7 @@ mod receipt_storage_adapter_unit_test {
)
.unwrap(),
query_id as u64,
&get_full_list_of_checks(),
&checks,
));
}
let mut receipt_ids = Vec::new();
Expand Down Expand Up @@ -186,6 +196,7 @@ mod receipt_storage_adapter_unit_test {
#[test]
fn safe_truncate_receipts_test(
domain_separator: Eip712Domain,
checks: Vec<ReceiptCheck>,
#[case] input: Vec<u64>,
#[case] limit: u64,
#[case] expected: Vec<u64>,
Expand Down Expand Up @@ -215,7 +226,7 @@ mod receipt_storage_adapter_unit_test {
)
.unwrap(),
i as u64, // Will use that to check the IDs
&get_full_list_of_checks(),
&checks,
),
));
}
Expand Down
195 changes: 156 additions & 39 deletions tap_core/src/checks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::tap_receipt::{Checking, ReceiptError, ReceiptResult, ReceiptWithState};
use ethers::types::Signature;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, sync::Arc};
use std::sync::Arc;
use tokio::sync::RwLock;

pub type ReceiptCheck = Arc<dyn Check>;
Expand All @@ -21,8 +20,6 @@ impl CheckingChecks {
match self {
Self::Pending(check) => {
let result = check.check(&receipt).await;
// *self = Self::Executed(result);
// self
Self::Executed(result)
}
Self::Executed(_) => self,
Expand Down Expand Up @@ -57,33 +54,6 @@ pub trait Check: std::fmt::Debug + Send + Sync {
}
}

#[derive(Debug, Serialize, Deserialize)]
struct UniqueCheck;

#[async_trait::async_trait]
#[typetag::serde]
impl Check for UniqueCheck {
async fn check(&self, _receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
println!("UniqueCheck");
Ok(())
}

async fn check_batch(&self, receipts: &[ReceiptWithState<Checking>]) -> Vec<ReceiptResult<()>> {
let mut signatures: HashSet<Signature> = HashSet::new();
let mut results = Vec::new();

for received_receipt in receipts {
let signature = received_receipt.signed_receipt.signature;
if signatures.insert(signature) {
results.push(Ok(()));
} else {
results.push(Err(ReceiptError::NonUniqueReceipt));
}
}
results
}
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TimestampCheck {
#[serde(skip)]
Expand Down Expand Up @@ -118,14 +88,161 @@ impl Check for TimestampCheck {
}
}

#[derive(Debug, Serialize, Deserialize)]
struct AllocationId;
#[cfg(test)]
pub mod tests {

use super::*;
use crate::tap_receipt::ReceivedReceipt;
use alloy_primitives::Address;
use alloy_sol_types::Eip712Domain;
use std::{
collections::{HashMap, HashSet},
fmt::Debug,
};

pub fn get_full_list_of_checks(
domain_separator: Eip712Domain,
valid_signers: HashSet<Address>,
allocation_ids: Arc<RwLock<HashSet<Address>>>,
receipt_storage: Arc<RwLock<HashMap<u64, ReceivedReceipt>>>,
query_appraisals: Arc<RwLock<HashMap<u64, u128>>>,
) -> Vec<ReceiptCheck> {
vec![
Arc::new(UniqueCheck { receipt_storage }),
Arc::new(ValueCheck { query_appraisals }),
Arc::new(AllocationIdCheck { allocation_ids }),
Arc::new(SignatureCheck {
domain_separator,
valid_signers,
}),
Arc::new(TimestampCheck::new(0)),
]
}

#[async_trait::async_trait]
#[typetag::serde]
impl Check for AllocationId {
async fn check(&self, _receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
println!("AllocationId");
Ok(())
#[derive(Serialize, Deserialize)]
struct UniqueCheck {
#[serde(skip)]
receipt_storage: Arc<RwLock<HashMap<u64, ReceivedReceipt>>>,
}
impl Debug for UniqueCheck {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "UniqueCheck")
}
}

#[async_trait::async_trait]
#[typetag::serde]
impl Check for UniqueCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
let receipt_storage = self.receipt_storage.read().await;
// let receipt_id = receipt.
receipt_storage
.iter()
.all(|(_stored_receipt_id, stored_receipt)| {
stored_receipt.signed_receipt().message != receipt.signed_receipt().message
})
.then_some(())
.ok_or(ReceiptError::NonUniqueReceipt)
}

async fn check_batch(
&self,
receipts: &[ReceiptWithState<Checking>],
) -> Vec<ReceiptResult<()>> {
let mut signatures: HashSet<ethers::types::Signature> = HashSet::new();
let mut results = Vec::new();

for received_receipt in receipts {
let signature = received_receipt.signed_receipt.signature;
if signatures.insert(signature) {
results.push(Ok(()));
} else {
results.push(Err(ReceiptError::NonUniqueReceipt));
}
}
results
}
}

#[derive(Debug, Serialize, Deserialize)]
struct ValueCheck {
#[serde(skip)]
query_appraisals: Arc<RwLock<HashMap<u64, u128>>>,
}

#[async_trait::async_trait]
#[typetag::serde]
impl Check for ValueCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
let query_id = receipt.query_id;
let value = receipt.signed_receipt().message.value;
let query_appraisals = self.query_appraisals.read().await;
let appraised_value =
query_appraisals
.get(&query_id)
.ok_or(ReceiptError::CheckFailedToComplete {
source_error_message: "Could not find query_appraisals".into(),
})?;

if value != *appraised_value {
Err(ReceiptError::InvalidValue {
received_value: value,
})
} else {
Ok(())
}
}
}

#[derive(Debug, Serialize, Deserialize)]
struct AllocationIdCheck {
#[serde(skip)]
allocation_ids: Arc<RwLock<HashSet<Address>>>,
}

#[async_trait::async_trait]
#[typetag::serde]
impl Check for AllocationIdCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
let received_allocation_id = receipt.signed_receipt().message.allocation_id;
if self
.allocation_ids
.read()
.await
.contains(&received_allocation_id)
{
Ok(())
} else {
Err(ReceiptError::InvalidAllocationID {
received_allocation_id,
})
}
}
}

#[derive(Debug, Serialize, Deserialize)]
struct SignatureCheck {
domain_separator: Eip712Domain,
valid_signers: HashSet<Address>,
}

#[async_trait::async_trait]
#[typetag::serde]
impl Check for SignatureCheck {
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> ReceiptResult<()> {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
.map_err(|e| ReceiptError::InvalidSignature {
source_error_message: e.to_string(),
})?;
if !self.valid_signers.contains(&recovered_address) {
Err(ReceiptError::InvalidSignature {
source_error_message: "Invalid signer".to_string(),
})
} else {
Ok(())
}
}
}
}
Loading

0 comments on commit 61e5137

Please sign in to comment.