From 7a7126c89d2f3957bdeb9c53edaec2c59d84b5a4 Mon Sep 17 00:00:00 2001 From: Stoyan Kirov Date: Fri, 10 May 2024 13:58:36 +0300 Subject: [PATCH] add mvp verification logic --- ampd/src/handlers/errors.rs | 2 + ampd/src/handlers/starknet_verify_msg.rs | 76 +++++++------- ampd/src/lib.rs | 32 +++--- ampd/src/starknet/json_rpc.rs | 75 +++++++------- ampd/src/starknet/types/array_span.rs | 2 +- ampd/src/starknet/verifier.rs | 120 +++++++++++------------ 6 files changed, 164 insertions(+), 143 deletions(-) diff --git a/ampd/src/handlers/errors.rs b/ampd/src/handlers/errors.rs index b926714f8..249f6578f 100644 --- a/ampd/src/handlers/errors.rs +++ b/ampd/src/handlers/errors.rs @@ -12,4 +12,6 @@ pub enum Error { Sign, #[error("failed to get transaction receipts")] TxReceipts, + #[error("failed to verify message against event")] + Verification, } diff --git a/ampd/src/handlers/starknet_verify_msg.rs b/ampd/src/handlers/starknet_verify_msg.rs index 355de0771..7dd3309b5 100644 --- a/ampd/src/handlers/starknet_verify_msg.rs +++ b/ampd/src/handlers/starknet_verify_msg.rs @@ -1,13 +1,13 @@ -use std::collections::HashSet; +use std::collections::HashMap; use std::convert::TryInto; use async_trait::async_trait; use axelar_wasm_std::voting::{PollId, Vote}; -// use connection_router_api::ChainName; use cosmrs::cosmwasm::MsgExecuteContract; +use error_stack::{FutureExt, ResultExt}; use events::Error::EventTypeMismatch; use events_derive::try_from; -use futures::future::join_all; +use futures::future::try_join_all; use serde::Deserialize; use tokio::sync::watch::Receiver; use tracing::info; @@ -17,7 +17,9 @@ use crate::event_processor::EventHandler; use crate::handlers::errors::Error; use crate::handlers::errors::Error::DeserializeEvent; use crate::queue::queued_broadcaster::BroadcasterClient; -use crate::starknet::verifier::MessageVerifier; +use crate::starknet::events::contract_call::ContractCallEvent; +use crate::starknet::json_rpc::StarknetClient; +use crate::starknet::verifier::verify_msg; use crate::types::{Hash, TMAddress}; type Result = error_stack::Result; @@ -38,42 +40,40 @@ struct PollStartedEvent { #[serde(rename = "_contract_address")] contract_address: TMAddress, poll_id: PollId, - // source_chain: ChainName, - // source_gateway_address: String, - // confirmation_height: u64, + source_gateway_address: String, expires_at: u64, messages: Vec, participants: Vec, } -pub struct Handler +pub struct Handler where - V: MessageVerifier, + C: StarknetClient, B: BroadcasterClient, { worker: TMAddress, voting_verifier: TMAddress, - msg_verifier: V, + rpc_client: C, broadcast_client: B, latest_block_height: Receiver, } -impl Handler +impl Handler where - V: MessageVerifier + Send + Sync, + C: StarknetClient + Send + Sync, B: BroadcasterClient, { pub fn new( worker: TMAddress, voting_verifier: TMAddress, - msg_verifier: V, + rpc_client: C, broadcast_client: B, latest_block_height: Receiver, ) -> Self { Self { worker, voting_verifier, - msg_verifier, + rpc_client, broadcast_client, latest_block_height, } @@ -99,21 +99,20 @@ where #[async_trait] impl EventHandler for Handler where - V: MessageVerifier + Send + Sync, + V: StarknetClient + Send + Sync, B: BroadcasterClient + Send + Sync, { type Err = Error; async fn handle(&self, event: &events::Event) -> Result<()> { let PollStartedEvent { - contract_address, poll_id, - // source_chain, - // source_gateway_address: _, - // confirmation_height: _, + source_gateway_address, messages, - expires_at, participants, + expires_at, + contract_address, + .. } = match event.try_into() as error_stack::Result<_, _> { Err(report) if matches!(report.current_context(), EventTypeMismatch(_)) => { return Ok(()); @@ -135,27 +134,30 @@ where return Ok(()); } - let tx_hashes: HashSet<_> = messages - .iter() - .map(|message| message.tx_id.as_str()) - .collect(); - - let unique_axl_msgs: Vec<&Message> = messages - .iter() - .filter(|m| tx_hashes.get(m.tx_id.as_str()).is_some()) - .collect(); - - let votes: Vec = join_all( - unique_axl_msgs - .into_iter() - .map(|msg| self.msg_verifier.verify_msg(msg)), + let events: HashMap = try_join_all( + messages + .iter() + .map(|msg| self.rpc_client.get_event_by_hash(msg.tx_id.as_str())), ) - .await + .change_context(Error::TxReceipts) + .await? .into_iter() - // TODO: Maybe log the errors (mostly with connection/serialization)? - .filter_map(|v| v.ok()) + .flatten() .collect(); + let mut votes = vec![]; + for msg in messages { + if !events.contains_key(&msg.tx_id) { + votes.push(Vote::NotFound); + continue; + } + votes.push(verify_msg( + events.get(&msg.tx_id).unwrap(), // safe to unwrap, because of previous check + &msg, + &source_gateway_address, + )); + } + println!("VOTES {:?}", votes); self.broadcast_votes(poll_id, votes).await diff --git a/ampd/src/lib.rs b/ampd/src/lib.rs index 7575bc11b..29541527d 100644 --- a/ampd/src/lib.rs +++ b/ampd/src/lib.rs @@ -1,4 +1,5 @@ use std::pin::Pin; +use std::str::FromStr; use std::time::Duration; use block_height_monitor::BlockHeightMonitor; @@ -13,6 +14,7 @@ use events::Event; use evm::finalizer::{pick, Finalization}; use evm::json_rpc::EthereumClient; use queue::queued_broadcaster::{QueuedBroadcaster, QueuedBroadcasterDriver}; +use starknet_providers::jsonrpc::HttpTransport; use state::StateUpdater; use thiserror::Error; use tofnd::grpc::{MultisigClient, SharableEcdsaClient}; @@ -22,6 +24,7 @@ use tokio_stream::Stream; use tokio_util::sync::CancellationToken; use tracing::info; use types::TMAddress; +use url::Url; use crate::asyncutil::task::{CancellableTask, TaskError, TaskGroup}; use crate::config::Config; @@ -344,17 +347,24 @@ where cosmwasm_contract, rpc_url, rpc_timeout: _, - } => self.create_handler_task( - "starknet-msg-verifier", - handlers::starknet_verify_msg::Handler::new( - worker.clone(), - cosmwasm_contract, - starknet::verifier::RPCMessageVerifier::new(rpc_url.as_str()), - self.broadcaster.client(), - self.block_height_monitor.latest_block_height(), - ), - stream_timeout, - ), + } => { + // let starknet_rpc_url = Url::from_str(rpc_url).unwrap(); + self.create_handler_task( + "starknet-msg-verifier", + handlers::starknet_verify_msg::Handler::new( + worker.clone(), + cosmwasm_contract, + starknet::json_rpc::Client::new_with_transport(HttpTransport::new( + &rpc_url.into(), + )) + .unwrap(), + // starknet::verifier::RPCMessageVerifier::new(rpc_url.as_str()), + self.broadcaster.client(), + self.block_height_monitor.latest_block_height(), + ), + stream_timeout, + ) + } }; self.event_processor = self.event_processor.add_task(task); } diff --git a/ampd/src/starknet/json_rpc.rs b/ampd/src/starknet/json_rpc.rs index 1823d4cbc..6da0cce6e 100644 --- a/ampd/src/starknet/json_rpc.rs +++ b/ampd/src/starknet/json_rpc.rs @@ -4,6 +4,7 @@ use std::str::FromStr; use async_trait::async_trait; +use error_stack::Report; use mockall::automock; use starknet_core::types::{ ExecutionResult, FieldElement, FromStrError, MaybePendingTransactionReceipt, TransactionReceipt, @@ -14,6 +15,8 @@ use thiserror::Error; use crate::starknet::events::contract_call::ContractCallEvent; +type Result = error_stack::Result; + #[derive(Debug, Error)] pub enum StarknetClientError { #[error(transparent)] @@ -44,54 +47,56 @@ where /// Constructor. /// Expects URL of any JSON RPC entry point of Starknet, which you can find /// as constants in the `networks.rs` module - pub fn new(transport: T) -> Result { + pub fn new_with_transport(transport: T) -> Result { Ok(Client { client: JsonRpcClient::new(transport), }) } } +/// A trait for fetching a ContractCall event, by a given tx_hash +/// and parsing parsing it into +/// `crate::starknet::events::contract_call::ContractCallEvent` #[automock] #[async_trait] -pub trait StarknetClient -where - T: JsonRpcTransport + Send + Sync + 'static, -{ - async fn get_event_by_hash( - &self, - tx_hash: &str, - ) -> Result, StarknetClientError>; +pub trait StarknetClient { + /// Attempts to fetch a ContractCall event, by a given `tx_hash`. + /// Returns a tuple `(tx_hash, event)` or a `StarknetClientError`. + async fn get_event_by_hash(&self, tx_hash: &str) + -> Result>; } #[async_trait] -impl StarknetClient for Client +impl StarknetClient for Client where T: JsonRpcTransport + Send + Sync + 'static, { - /// Using given transaction hash, tries to fetch it from given - /// `starknet_url`. Returns error if request fails, `false` if internal - /// error returned by querry and `true` if transaction found async fn get_event_by_hash( &self, tx_hash: &str, - ) -> Result, StarknetClientError> { - let tx_hash_felt = FieldElement::from_str(tx_hash)?; + ) -> Result> { + let tx_hash_felt = + FieldElement::from_str(tx_hash).map_err(StarknetClientError::FeltFromString)?; // TODO: Check ACCEPTED ON L1 times and decide if we should use it // // Finality status is always at least ACCEPTED_ON_L2 and this is what we're // looking for, because ACCEPTED_ON_L1 (Ethereum) will take a very long time. - let receipt_type = self.client.get_transaction_receipt(tx_hash_felt).await?; + let receipt_type = self + .client + .get_transaction_receipt(tx_hash_felt) + .await + .map_err(StarknetClientError::FetchingReceipt)?; if *receipt_type.execution_result() != ExecutionResult::Succeeded { - return Err(StarknetClientError::UnsuccessfulTx); + return Err(Report::new(StarknetClientError::UnsuccessfulTx)); } let event: Option<(String, ContractCallEvent)> = match receipt_type { // TODO: There is also a PendingReceipt type. Should we handle it? MaybePendingTransactionReceipt::Receipt(receipt) => match receipt { TransactionReceipt::Invoke(tx) => { - // There should be only one ContractCall event per gateway tx + // NOTE: There should be only one ContractCall event per gateway tx tx.events .iter() .filter_map(|e| { @@ -134,7 +139,7 @@ mod test { #[tokio::test] async fn invalid_tx_hash_stirng() { - let mock_client = Client::new(ValidMockTransport).unwrap(); + let mock_client = Client::new_with_transport(ValidMockTransport).unwrap(); let contract_call_event = mock_client.get_event_by_hash("not a valid felt").await; assert!(contract_call_event.is_err()); @@ -142,7 +147,7 @@ mod test { #[tokio::test] async fn deploy_account_tx_fetch() { - let mock_client = Client::new(DeployAccountMockTransport).unwrap(); + let mock_client = Client::new_with_transport(DeployAccountMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -152,7 +157,7 @@ mod test { #[tokio::test] async fn deploy_tx_fetch() { - let mock_client = Client::new(DeployMockTransport).unwrap(); + let mock_client = Client::new_with_transport(DeployMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -162,7 +167,7 @@ mod test { #[tokio::test] async fn l1_handler_tx_fetch() { - let mock_client = Client::new(L1HandlerMockTransport).unwrap(); + let mock_client = Client::new_with_transport(L1HandlerMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -172,7 +177,7 @@ mod test { #[tokio::test] async fn declare_tx_fetch() { - let mock_client = Client::new(DeclareMockTransport).unwrap(); + let mock_client = Client::new_with_transport(DeclareMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -182,7 +187,8 @@ mod test { #[tokio::test] async fn invalid_contract_call_event_tx_fetch() { - let mock_client = Client::new(InvalidContractCallEventMockTransport).unwrap(); + let mock_client = + Client::new_with_transport(InvalidContractCallEventMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -192,7 +198,7 @@ mod test { #[tokio::test] async fn no_events_tx_fetch() { - let mock_client = Client::new(NoEventsMockTransport).unwrap(); + let mock_client = Client::new_with_transport(NoEventsMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -202,20 +208,19 @@ mod test { #[tokio::test] async fn reverted_tx_fetch() { - let mock_client = Client::new(RevertedMockTransport).unwrap(); + let mock_client = Client::new_with_transport(RevertedMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; - assert!(matches!( - contract_call_event.unwrap_err(), - StarknetClientError::UnsuccessfulTx - )); + assert!(contract_call_event + .unwrap_err() + .contains::()); } #[tokio::test] async fn failing_tx_fetch() { - let mock_client = Client::new(FailingMockTransport).unwrap(); + let mock_client = Client::new_with_transport(FailingMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await; @@ -225,7 +230,7 @@ mod test { #[tokio::test] async fn successful_tx_fetch() { - let mock_client = Client::new(ValidMockTransport).unwrap(); + let mock_client = Client::new_with_transport(ValidMockTransport).unwrap(); let contract_call_event = mock_client .get_event_by_hash(FieldElement::ONE.to_string().as_str()) .await @@ -239,6 +244,8 @@ mod test { assert_eq!( contract_call_event.1, ContractCallEvent { + from_contract_addr: + "0x0000000000000000000000000000000000000000000000000000000000000002".to_owned(), destination_address: String::from("hello"), destination_chain: String::from("destination_chain"), source_address: String::from( @@ -638,7 +645,7 @@ mod test { \"jsonrpc\": \"2.0\", \"result\": { \"type\": \"INVOKE\", - \"transaction_hash\": \"0x000000000000000000000000000000000000000000000000000000000000001\", + \"transaction_hash\": \"0x0000000000000000000000000000000000000000000000000000000000000001\", \"actual_fee\": { \"amount\": \"0x3062e4c46d4\", \"unit\": \"WEI\" @@ -650,7 +657,7 @@ mod test { \"messages_sent\": [], \"events\": [ { - \"from_address\": \"0x000000000000000000000000000000000000000000000000000000000000002\", + \"from_address\": \"0x0000000000000000000000000000000000000000000000000000000000000002\", \"keys\": [ \"0x034d074b86d78f064ec0a29639fcfab989c7a3ea6343653633624b2df9ec08f6\", \"0x00000000000000000000000000000064657374696e6174696f6e5f636861696e\" diff --git a/ampd/src/starknet/types/array_span.rs b/ampd/src/starknet/types/array_span.rs index 9f57913bd..e49f05557 100644 --- a/ampd/src/starknet/types/array_span.rs +++ b/ampd/src/starknet/types/array_span.rs @@ -110,7 +110,7 @@ mod array_span_tests { #[test] fn try_from_failed_to_parse_elements_length_to_u32() { - // the string "hello", but element counte bigger than u32::max + // the string "hello", but element count is bigger than u32::max let data: Result, FromStrError> = vec![ "0x00FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", "0x0000000000000000000000000000000000000000000000000000000000000068", diff --git a/ampd/src/starknet/verifier.rs b/ampd/src/starknet/verifier.rs index 810909db0..f89777cc5 100644 --- a/ampd/src/starknet/verifier.rs +++ b/ampd/src/starknet/verifier.rs @@ -1,76 +1,76 @@ use axelar_wasm_std::voting::Vote; -use mockall::automock; -use starknet_providers::jsonrpc::HttpTransport; -use thiserror::Error; -use tonic::async_trait; -use url::Url; use super::events::contract_call::ContractCallEvent; -use super::json_rpc::{Client, StarknetClientError}; use crate::handlers::starknet_verify_msg::Message; -use crate::starknet::json_rpc::StarknetClient; -#[derive(Error, Debug)] -pub enum VerifierError { - #[error("JSON-RPC error")] - JsonRPC, - #[error("block number missing in JSON-RPC response for finalized block")] - MissBlockNumber, - #[error("failed to fetch event: {0}")] - FetchEvent(#[from] StarknetClientError), +/// Attempts to fetch the tx provided in `axl_msg.tx_id`. +/// If successful, extracts and parses the ContractCall event +/// and compares it to the message from the relayer (via PollStarted event). +/// Also checks if the source_gateway_address with which +/// the voting verifier has been instantiated is the same address from +/// which the ContractCall event is coming. +pub fn verify_msg( + starknet_event: &ContractCallEvent, + msg: &Message, + source_gateway_address: &str, +) -> Vote { + dbg!(starknet_event); + dbg!(msg); + dbg!(source_gateway_address); + if *starknet_event == *msg && starknet_event.from_contract_addr == source_gateway_address { + Vote::SucceededOnChain + } else { + Vote::NotFound + } } -#[automock] -#[async_trait] -pub trait MessageVerifier { - async fn verify_msg(&self, axl_msg: &Message) -> core::result::Result; +impl PartialEq for ContractCallEvent { + fn eq(&self, axl_msg: &Message) -> bool { + axl_msg.source_address == self.source_address + && axl_msg.destination_chain == self.destination_chain + && axl_msg.destination_address == self.destination_address + && axl_msg.payload_hash == self.payload_hash + } } -pub struct RPCMessageVerifier { - client: Client, -} +#[cfg(test)] +mod tests { + use ethers::types::H256; + use starknet_core::utils::{parse_cairo_short_string, starknet_keccak}; + use starknet_providers::jsonrpc::HttpTransport; -impl RPCMessageVerifier { - pub fn new(url: &str) -> Self { - Self { - client: Client::new(HttpTransport::new(Url::parse(url).unwrap())).unwrap(), /* todoo scale error ? */ - } - } -} + use crate::starknet::events::contract_call::ContractCallEvent; + use crate::starknet::json_rpc::MockStarknetClient; -#[async_trait] -impl MessageVerifier for RPCMessageVerifier { - /// Verify that a tx with a certain `tx_hash` has happened on Starknet. - /// `tx_hash` comes from the the Axelar `Message::tx_id` - async fn verify_msg(&self, msg: &Message) -> core::result::Result { - match self - .client - .get_event_by_hash(msg.tx_id.as_str()) - .await - .map_err(VerifierError::FetchEvent)? - { - Some((event_tx_hash, contract_call_event)) => { - println!("MESSAGE {:?}", msg); - println!("CONTRACT_CALL_EVENT {:?}", contract_call_event); - println!("EVENT_TX_HASH {:?}", event_tx_hash); - if event_tx_hash == msg.tx_id && contract_call_event == msg - // && event.type_ == EventType::ContractCall.struct_tag(gateway_address) - { - Ok(Vote::SucceededOnChain) - } else { - Ok(Vote::FailedOnChain) - } - } - None => Ok(Vote::NotFound), + // "hello" as payload + // "hello" as destination address + // "some_contract_address" as source address + // "destination_chain" as destination_chain + fn mock_valid_event() -> ContractCallEvent { + let from_contract_addr = + parse_cairo_short_string(&starknet_keccak("some_contract_address".as_bytes())).unwrap(); + ContractCallEvent { + from_contract_addr, + destination_address: String::from("hello"), + destination_chain: String::from("destination_chain"), + source_address: String::from( + "0x00b3ff441a68610b30fd5e2abbf3a1548eb6ba6f3559f2862bf2dc757e5828ca", + ), + payload_hash: H256::from_slice(&[ + 28u8, 138, 255, 149, 6, 133, 194, 237, 75, 195, 23, 79, 52, 114, 40, 123, 86, 217, + 81, 123, 156, 148, 129, 39, 49, 154, 9, 167, 163, 109, 234, 200, + ]), } } -} -impl PartialEq<&Message> for ContractCallEvent { - fn eq(&self, axl_msg: &&Message) -> bool { - axl_msg.source_address == self.source_address - && axl_msg.destination_chain == self.destination_chain - && axl_msg.destination_address == self.destination_address - && axl_msg.payload_hash == self.payload_hash + fn shoud_verify_event() { + // let mut mock_client = MockStarknetClient::::new(); + // mock_client + // .expect_get_event_by_hash() + // .returning(|_| Ok(Some(("some_tx_hash".to_owned(), + // mock_valid_event())))); + // + // let verifier = RPCMessageVerifier::new("doesnt_matter"); + // verifier.client = mock_client; } }