Skip to content

Commit

Permalink
refactor: use verify_signer from escrow_adapter
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Inacio <[email protected]>
  • Loading branch information
gusinacio committed Mar 8, 2024
1 parent cf4e262 commit 637c24c
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 82 deletions.
1 change: 0 additions & 1 deletion tap_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ strum = "0.24.1"
strum_macros = "0.24.3"
async-trait = "0.1.72"
tokio = { version = "1.29.1", features = ["macros", "rt-multi-thread"] }
futures = "0.3.17"

[dev-dependencies]
criterion = { version = "0.5", features = ["async_std"] }
Expand Down
5 changes: 5 additions & 0 deletions tap_core/src/adapters/escrow_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,9 @@ pub trait EscrowAdapter {
sender_id: Address,
value: u128,
) -> Result<(), Self::AdapterError>;

async fn verify_signer(
&self,
signer_address: Address
) -> Result<bool, Self::AdapterError>;
}
16 changes: 14 additions & 2 deletions tap_core/src/adapters/mock/executor_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ pub struct ExecutorMock {
rav_storage: RAVStorage,
receipt_storage: ReceiptStorage,
unique_id: Arc<RwLock<u64>>,

sender_escrow_storage: EscrowStorage,

timestamp_check: Arc<TimestampCheck>,
sender_address: Option<Address>,
}

impl ExecutorMock {
Expand All @@ -56,9 +55,15 @@ impl ExecutorMock {
unique_id: Arc::new(RwLock::new(0)),
sender_escrow_storage,
timestamp_check,
sender_address: None,
}
}

pub fn with_sender_address(mut self, sender_address: Address) -> Self {
self.sender_address = Some(sender_address);
self
}

pub async fn retrieve_receipt_by_id(
&self,
receipt_id: u64,
Expand Down Expand Up @@ -241,4 +246,11 @@ impl EscrowAdapter for ExecutorMock {
) -> Result<(), Self::AdapterError> {
self.reduce_escrow(sender_id, value)
}

async fn verify_signer(&self, signer_address: Address) -> Result<bool, Self::AdapterError> {
Ok(self
.sender_address
.map(|sender| signer_address == sender)
.unwrap_or(false))
}
}
16 changes: 4 additions & 12 deletions tap_core/src/tap_manager/manager.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
// Copyright 2023-, Semiotic AI, Inc.
// SPDX-License-Identifier: Apache-2.0

use alloy_primitives::Address;
use alloy_sol_types::Eip712Domain;
use futures::Future;

use super::{RAVRequest, SignedRAV, SignedReceipt};
use crate::{
Expand Down Expand Up @@ -59,27 +57,21 @@ where

impl<E> Manager<E>
where
E: RAVStore,
E: RAVStore + EscrowAdapter,
{
/// Verify `signed_rav` matches all values on `expected_rav`, and that `signed_rav` has a valid signer.
///
/// # Errors
///
/// Returns [`Error::AdapterError`] if there are any errors while storing RAV
///
pub async fn verify_and_store_rav<F, Fut, Err>(
pub async fn verify_and_store_rav(
&self,
expected_rav: ReceiptAggregateVoucher,
signed_rav: SignedRAV,
verify_signer: F,
) -> std::result::Result<(), Error>
where
F: FnOnce(Address) -> Fut,
Fut: Future<Output = Result<bool, Err>>,
Err: std::fmt::Display,
{
) -> std::result::Result<(), Error> {
self.receipt_auditor
.check_rav_signature(&signed_rav, verify_signer)
.check_rav_signature(&signed_rav)
.await?;

if signed_rav.message != expected_rav {
Expand Down
34 changes: 8 additions & 26 deletions tap_core/src/tap_manager/test/manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ fn executor_mock(
domain_separator: Eip712Domain,
allocation_ids: Vec<Address>,
sender_ids: Vec<Address>,
keys: (LocalWallet, Address),
) -> ExecutorFixture {
let escrow_storage = Arc::new(RwLock::new(HashMap::new()));
let rav_storage = Arc::new(RwLock::new(None));
Expand All @@ -84,7 +85,8 @@ fn executor_mock(
receipt_storage.clone(),
escrow_storage.clone(),
timestamp_check.clone(),
);
)
.with_sender_address(keys.1);

let mut checks = get_full_list_of_checks(
domain_separator,
Expand Down Expand Up @@ -189,11 +191,7 @@ async fn manager_create_rav_request_all_valid_receipts(
EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0)
.unwrap();
assert!(manager
.verify_and_store_rav(
rav_request.expected_rav,
signed_rav,
|address: Address| async move { Ok::<bool, String>(keys.1 == address) }
)
.verify_and_store_rav(rav_request.expected_rav, signed_rav)
.await
.is_ok());
}
Expand Down Expand Up @@ -260,11 +258,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0)
.unwrap();
assert!(manager
.verify_and_store_rav(
rav_request.expected_rav,
signed_rav,
|address: Address| async move { Ok::<bool, String>(keys.1 == address) }
)
.verify_and_store_rav(rav_request.expected_rav, signed_rav)
.await
.is_ok());

Expand Down Expand Up @@ -310,11 +304,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts(
EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0)
.unwrap();
assert!(manager
.verify_and_store_rav(
rav_request.expected_rav,
signed_rav,
|address: Address| async move { Ok::<bool, String>(keys.1 == address) }
)
.verify_and_store_rav(rav_request.expected_rav, signed_rav)
.await
.is_ok());
}
Expand Down Expand Up @@ -391,11 +381,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
)
.unwrap();
assert!(manager
.verify_and_store_rav(
rav_request_1.expected_rav,
signed_rav_1,
|address: Address| async move { Ok::<bool, String>(keys.1 == address) }
)
.verify_and_store_rav(rav_request_1.expected_rav, signed_rav_1)
.await
.is_ok());

Expand Down Expand Up @@ -456,11 +442,7 @@ async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_tim
)
.unwrap();
assert!(manager
.verify_and_store_rav(
rav_request_2.expected_rav,
signed_rav_2,
|address: Address| async move { Ok::<bool, String>(keys.1 == address) }
)
.verify_and_store_rav(rav_request_2.expected_rav, signed_rav_2)
.await
.is_ok());
}
41 changes: 16 additions & 25 deletions tap_core/src/tap_receipt/receipt_auditor.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
// Copyright 2023-, Semiotic AI, Inc.
// SPDX-License-Identifier: Apache-2.0

use alloy_primitives::Address;
use alloy_sol_types::Eip712Domain;
use futures::Future;

use crate::{
adapters::escrow_adapter::EscrowAdapter,
Expand All @@ -26,29 +24,6 @@ impl<E> ReceiptAuditor<E> {
executor,
}
}

pub async fn check_rav_signature<F, Fut, Err>(
&self,
signed_rav: &SignedRAV,
verify_signer: F,
) -> Result<(), Error>
where
F: FnOnce(Address) -> Fut,
Fut: Future<Output = Result<bool, Err>>,
Err: std::fmt::Display,
{
let recovered_address = signed_rav.recover_signer(&self.domain_separator)?;
if verify_signer(recovered_address)
.await
.map_err(|e| Error::FailedToVerifySigner(e.to_string()))?
{
Ok(())
} else {
Err(Error::InvalidRecoveredSigner {
address: recovered_address,
})
}
}
}

impl<E> ReceiptAuditor<E>
Expand Down Expand Up @@ -77,4 +52,20 @@ where

Ok(())
}

pub async fn check_rav_signature(&self, signed_rav: &SignedRAV) -> Result<(), Error> {
let recovered_address = signed_rav.recover_signer(&self.domain_separator)?;
if self
.executor
.verify_signer(recovered_address)
.await
.map_err(|e| Error::FailedToVerifySigner(e.to_string()))?
{
Ok(())
} else {
Err(Error::InvalidRecoveredSigner {
address: recovered_address,
})
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ fn executor_mock(
domain_separator: Eip712Domain,
allocation_ids: Vec<Address>,
sender_ids: Vec<Address>,
keys: (LocalWallet, Address),
) -> ExecutorFixture {
let escrow_storage = Arc::new(RwLock::new(HashMap::new()));
let rav_storage = Arc::new(RwLock::new(None));
Expand All @@ -82,7 +83,8 @@ fn executor_mock(
receipt_storage.clone(),
escrow_storage.clone(),
timestamp_check.clone(),
);
)
.with_sender_address(keys.1);
let mut checks = get_full_list_of_checks(
domain_separator,
sender_ids.iter().cloned().collect(),
Expand Down
14 changes: 1 addition & 13 deletions tap_integration_tests/tests/indexer_mock/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::sync::{
Arc,
};

use alloy_primitives::Address;
use alloy_sol_types::Eip712Domain;
use anyhow::{Error, Result};
use jsonrpsee::{
Expand Down Expand Up @@ -51,7 +50,6 @@ pub struct RpcManager<E> {
receipt_count: Arc<AtomicU64>, // Thread-safe atomic counter for receipts
threshold: u64, // The count at which a RAV request will be triggered
aggregator_client: (HttpClient, String), // HTTP client for sending requests to the aggregator server
sender_id: Address, // The sender address
}

/// Implementation for `RpcManager`, includes the constructor and the `request` method.
Expand All @@ -66,7 +64,6 @@ where
executor: E,
required_checks: Checks,
threshold: u64,
sender_id: Address,
aggregate_server_address: String,
aggregate_server_api_version: String,
) -> Result<Self> {
Expand All @@ -78,7 +75,6 @@ where
)),
receipt_count: Arc::new(AtomicU64::new(0)),
threshold,
sender_id,
aggregator_client: (
HttpClientBuilder::default().build(aggregate_server_address)?,
aggregate_server_api_version,
Expand Down Expand Up @@ -118,7 +114,6 @@ where
time_stamp_buffer,
&self.aggregator_client,
self.threshold as usize,
self.sender_id,
)
.await
{
Expand Down Expand Up @@ -146,7 +141,6 @@ pub async fn run_server<E>(
threshold: u64, // The count at which a RAV request will be triggered
aggregate_server_address: String, // Address of the aggregator server
aggregate_server_api_version: String, // API version of the aggregator server
sender_id: Address, // The sender address
) -> Result<(ServerHandle, std::net::SocketAddr)>
where
E: ReceiptStore
Expand All @@ -172,7 +166,6 @@ where
executor,
required_checks,
threshold,
sender_id,
aggregate_server_address,
aggregate_server_api_version,
)?;
Expand All @@ -187,7 +180,6 @@ async fn request_rav<E>(
time_stamp_buffer: u64, // Buffer for timestamping, see tap_core for details
aggregator_client: &(HttpClient, String), // HttpClient for making requests to the tap_aggregator server
threshold: usize,
expected_sender_id: Address,
) -> Result<()>
where
E: ReceiptRead + RAVRead + RAVStore + EscrowAdapter,
Expand All @@ -208,11 +200,7 @@ where
.request("aggregate_receipts", params)
.await?;
manager
.verify_and_store_rav(
rav_request.expected_rav,
remote_rav_result.data,
|address| async move { Ok::<bool, String>(address == expected_sender_id) },
)
.verify_and_store_rav(rav_request.expected_rav, remote_rav_result.data)
.await?;

// For these tests, we expect every receipt to be valid, i.e. there should be no invalid receipts, nor any missing receipts (less than the expected threshold).
Expand Down
3 changes: 1 addition & 2 deletions tap_integration_tests/tests/showcase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,12 +871,11 @@ async fn start_indexer_server(
let (server_handle, socket_addr) = indexer_mock::run_server(
http_port,
domain_separator,
executor,
executor.with_sender_address(sender_id),
required_checks,
receipt_threshold,
aggregate_server_address,
aggregate_server_api_version(),
sender_id,
)
.await?;

Expand Down

0 comments on commit 637c24c

Please sign in to comment.