Skip to content

Commit

Permalink
update: clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
heemankv committed Nov 8, 2024
1 parent cde59b3 commit 98a5bc0
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 133 deletions.
4 changes: 3 additions & 1 deletion crates/orchestrator/src/alerts/aws_sns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use crate::config::ProviderConfig;

#[derive(Debug, Clone)]
pub struct AWSSNSParams {
// TODO: convert to ARN type, and validate it
// NOTE: aws is using str to represent ARN : https://docs.aws.amazon.com/sdk-for-rust/latest/dg/rust_sns_code_examples.html
pub sns_arn: String,
}

Expand All @@ -23,7 +25,7 @@ pub struct AWSSNS {
}

impl AWSSNS {
pub async fn new_with_settings(aws_sns_params: &AWSSNSParams, provider_config: Arc<ProviderConfig>) -> Self {
pub async fn new_with_params(aws_sns_params: &AWSSNSParams, provider_config: Arc<ProviderConfig>) -> Self {
let config = provider_config.get_aws_client_or_panic();
Self { client: Client::new(config), topic_arn: aws_sns_params.sns_arn.clone() }
}
Expand Down
15 changes: 8 additions & 7 deletions crates/orchestrator/src/cli/aws_config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use clap::Args;
use serde::Serialize;
use url::Url;

/// Parameters used to config AWS.
#[derive(Debug, Clone, Args, Serialize)]
#[group(requires_all = ["aws_access_key_id", "aws_secret_access_key", "aws_region"])]
pub struct AWSConfigParams {
pub struct AWSConfigCliArgs {
/// The access key ID.
#[arg(env = "AWS_ACCESS_KEY_ID", long)]
pub aws_access_key_id: String,
Expand All @@ -16,12 +17,12 @@ pub struct AWSConfigParams {
/// The region.
#[arg(env = "AWS_REGION", long)]
pub aws_region: String,
}

/// The endpoint URL.
#[arg(env = "AWS_ENDPOINT_URL", long, default_value = "http://localhost.localstack.cloud:4566")]
pub aws_endpoint_url: String,

/// The default region.
#[arg(env = "AWS_DEFAULT_REGION", long, default_value = "localhost")]
pub struct AWSConfigParams {
pub aws_access_key_id: String,
pub aws_secret_access_key: String,
pub aws_region: String,
pub aws_endpoint_url: Url,
pub aws_default_region: String,
}
102 changes: 55 additions & 47 deletions crates/orchestrator/src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alert::AlertParams;
use aws_config::AWSConfigParams;
use aws_config::{AWSConfigCliArgs, AWSConfigParams};
use clap::{ArgGroup, Parser};
use da::DaParams;
use database::DatabaseParams;
Expand Down Expand Up @@ -78,7 +78,7 @@ pub mod storage;
pub struct RunCmd {
// AWS Config
#[clap(flatten)]
pub aws_config_args: AWSConfigParams,
pub aws_config_args: AWSConfigCliArgs,

// Settlement Layer
#[command(flatten)]
Expand Down Expand Up @@ -129,52 +129,27 @@ pub struct RunCmd {
}

impl RunCmd {
pub fn validate_settlement_params(&self) -> Result<settlement::SettlementParams, String> {
match (self.ethereum_args.settle_on_ethereum, self.starknet_args.settle_on_starknet) {
(true, false) => {
// TODO: Ensure Starknet params are not provided
pub fn validate_aws_config_params(&self) -> Result<AWSConfigParams, String> {
let aws_endpoint_url = Url::parse("http://localhost.localstack.cloud:4566").unwrap();
let aws_default_region = "localhost".to_string();

// Get Ethereum params or error if none provided
// Either all the values are provided or panic
let ethereum_params = EthereumSettlementParams {
ethereum_rpc_url: self.ethereum_args.ethereum_rpc_url.clone().unwrap(),
ethereum_private_key: self.ethereum_args.ethereum_private_key.clone().unwrap(),
l1_core_contract_address: self.ethereum_args.l1_core_contract_address.clone().unwrap(),
starknet_operator_address: self.ethereum_args.starknet_operator_address.clone().unwrap(),
};
Ok(SettlementParams::Ethereum(ethereum_params))
}
(false, true) => {
// TODO: Ensure Ethereum params are not provided
tracing::warn!("Setting AWS_ENDPOINT_URL to {} for AWS SDK to use", aws_endpoint_url);
tracing::warn!("Setting AWS_DEFAULT_REGION to {} for Omniqueue to use", aws_default_region);

// Get Starknet params or error if none provided
// Either all the values are provided or panic
let starknet_params = StarknetSettlementParams {
starknet_rpc_url: self.starknet_args.starknet_rpc_url.clone().unwrap(),
starknet_private_key: self.starknet_args.starknet_private_key.clone().unwrap(),
starknet_account_address: self.starknet_args.starknet_account_address.clone().unwrap(),
starknet_cairo_core_contract_address: self
.starknet_args
.starknet_cairo_core_contract_address
.clone()
.unwrap(),
starknet_finality_retry_wait_in_secs: self
.starknet_args
.starknet_finality_retry_wait_in_secs
.unwrap(),
madara_binary_path: self.starknet_args.madara_binary_path.clone().unwrap(),
};
Ok(SettlementParams::Starknet(starknet_params))
}
(true, true) | (false, false) => Err("Exactly one settlement layer must be selected".to_string()),
}
Ok(AWSConfigParams {
aws_access_key_id: self.aws_config_args.aws_access_key_id.clone(),
aws_secret_access_key: self.aws_config_args.aws_secret_access_key.clone(),
aws_region: self.aws_config_args.aws_region.clone(),
aws_endpoint_url,
aws_default_region,
})
}

pub fn validate_storage_params(&self) -> Result<StorageParams, String> {
if self.aws_s3_args.aws_s3 {
Ok(StorageParams::AWSS3(AWSS3Params { bucket_name: self.aws_s3_args.bucket_name.clone().unwrap() }))
pub fn validate_alert_params(&self) -> Result<AlertParams, String> {
if self.aws_sns_args.aws_sns {
Ok(AlertParams::AWSSNS(AWSSNSParams { sns_arn: self.aws_sns_args.sns_arn.clone().unwrap() }))
} else {
Err("Only AWS S3 is supported as of now".to_string())
Err("Only AWS SNS is supported as of now".to_string())
}
}

Expand All @@ -190,11 +165,11 @@ impl RunCmd {
}
}

pub fn validate_alert_params(&self) -> Result<AlertParams, String> {
if self.aws_sns_args.aws_sns {
Ok(AlertParams::AWSSNS(AWSSNSParams { sns_arn: self.aws_sns_args.sns_arn.clone().unwrap() }))
pub fn validate_storage_params(&self) -> Result<StorageParams, String> {
if self.aws_s3_args.aws_s3 {
Ok(StorageParams::AWSS3(AWSS3Params { bucket_name: self.aws_s3_args.bucket_name.clone().unwrap() }))
} else {
Err("Only AWS SNS is supported as of now".to_string())
Err("Only AWS S3 is supported as of now".to_string())
}
}

Expand All @@ -219,6 +194,39 @@ impl RunCmd {
}
}

pub fn validate_settlement_params(&self) -> Result<settlement::SettlementParams, String> {
match (self.ethereum_args.settle_on_ethereum, self.starknet_args.settle_on_starknet) {
(true, false) => {
let ethereum_params = EthereumSettlementParams {
ethereum_rpc_url: self.ethereum_args.ethereum_rpc_url.clone().unwrap(),
ethereum_private_key: self.ethereum_args.ethereum_private_key.clone().unwrap(),
l1_core_contract_address: self.ethereum_args.l1_core_contract_address.clone().unwrap(),
starknet_operator_address: self.ethereum_args.starknet_operator_address.clone().unwrap(),
};
Ok(SettlementParams::Ethereum(ethereum_params))
}
(false, true) => {
let starknet_params = StarknetSettlementParams {
starknet_rpc_url: self.starknet_args.starknet_rpc_url.clone().unwrap(),
starknet_private_key: self.starknet_args.starknet_private_key.clone().unwrap(),
starknet_account_address: self.starknet_args.starknet_account_address.clone().unwrap(),
starknet_cairo_core_contract_address: self
.starknet_args
.starknet_cairo_core_contract_address
.clone()
.unwrap(),
starknet_finality_retry_wait_in_secs: self
.starknet_args
.starknet_finality_retry_wait_in_secs
.unwrap(),
madara_binary_path: self.starknet_args.madara_binary_path.clone().unwrap(),
};
Ok(SettlementParams::Starknet(starknet_params))
}
(true, true) | (false, false) => Err("Exactly one settlement layer must be selected".to_string()),
}
}

pub fn validate_prover_params(&self) -> Result<ProverParams, String> {
if self.sharp_args.sharp {
Ok(ProverParams::Sharp(SharpParams {
Expand Down
63 changes: 19 additions & 44 deletions crates/orchestrator/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::routes::ServerParams;
/// by calling `config` function.
pub struct Config {
/// The orchestrator config
orchestrator_config: OrchestratorConfig,
orchestrator_params: OrchestratorParams,
/// The starknet client to get data from the node
starknet_client: Arc<JsonRpcClient<HttpTransport>>,
/// The DA client to interact with the DA layer
Expand All @@ -70,7 +70,7 @@ pub struct ServiceParams {
pub min_block_to_process: Option<u64>,
}

pub struct OrchestratorConfig {
pub struct OrchestratorParams {
pub madara_rpc_url: Url,
pub snos_config: SNOSParams,
pub service_config: ServiceParams,
Expand Down Expand Up @@ -108,17 +108,17 @@ pub async fn get_aws_config(aws_config: &AWSConfigParams) -> SdkConfig {
pub async fn init_config(run_cmd: &RunCmd) -> color_eyre::Result<Arc<Config>> {
dotenv().ok();

let aws_config = &run_cmd.aws_config_args;
let aws_config = &run_cmd.validate_aws_config_params().expect("Failed to validate AWS config params");
let provider_config = Arc::new(ProviderConfig::AWS(Box::new(get_aws_config(aws_config).await)));

let orchestrator_config = OrchestratorConfig {
let orchestrator_params = OrchestratorParams {
madara_rpc_url: run_cmd.madara_rpc_url.clone(),
snos_config: run_cmd.validate_snos_params().expect("Failed to validate SNOS params"),
service_config: run_cmd.validate_service_params().expect("Failed to validate service params"),
server_config: run_cmd.validate_server_params().expect("Failed to validate server params"),
};

let provider = JsonRpcClient::new(HttpTransport::new(orchestrator_config.madara_rpc_url.clone()));
let provider = JsonRpcClient::new(HttpTransport::new(orchestrator_params.madara_rpc_url.clone()));

// init database
let database_params =
Expand Down Expand Up @@ -156,7 +156,7 @@ pub async fn init_config(run_cmd: &RunCmd) -> color_eyre::Result<Arc<Config>> {
let queue = build_queue_client(&queue_params);

Ok(Arc::new(Config::new(
orchestrator_config,
orchestrator_params,
Arc::new(provider),
da_client,
prover_client,
Expand All @@ -172,7 +172,7 @@ impl Config {
/// Create a new config
#[allow(clippy::too_many_arguments)]
pub fn new(
orchestrator_config: OrchestratorConfig,
orchestrator_params: OrchestratorParams,
starknet_client: Arc<JsonRpcClient<HttpTransport>>,
da_client: Box<dyn DaClient>,
prover_client: Box<dyn ProverClient>,
Expand All @@ -183,7 +183,7 @@ impl Config {
alerts: Box<dyn Alerts>,
) -> Self {
Self {
orchestrator_config,
orchestrator_params,
starknet_client,
da_client,
prover_client,
Expand All @@ -197,22 +197,22 @@ impl Config {

/// Returns the starknet rpc url
pub fn starknet_rpc_url(&self) -> &Url {
&self.orchestrator_config.madara_rpc_url
&self.orchestrator_params.madara_rpc_url
}

/// Returns the server config
pub fn server_config(&self) -> &ServerParams {
&self.orchestrator_config.server_config
&self.orchestrator_params.server_config
}

/// Returns the snos rpc url
pub fn snos_config(&self) -> &SNOSParams {
&self.orchestrator_config.snos_config
&self.orchestrator_params.snos_config
}

/// Returns the service config
pub fn service_config(&self) -> &ServiceParams {
&self.orchestrator_config.service_config
&self.orchestrator_params.service_config
}

/// Returns the starknet client
Expand Down Expand Up @@ -270,7 +270,7 @@ pub async fn build_da_client(da_params: &DaParams) -> Box<dyn DaClient + Send +
/// Builds the prover service based on the environment variable PROVER_SERVICE
pub fn build_prover_service(prover_params: &ProverParams) -> Box<dyn ProverClient> {
match prover_params {
ProverParams::Sharp(sharp_params) => Box::new(SharpProverService::new_with_settings(sharp_params)),
ProverParams::Sharp(sharp_params) => Box::new(SharpProverService::new_with_params(sharp_params)),
}
}

Expand All @@ -282,11 +282,11 @@ pub async fn build_settlement_client(
SettlementParams::Ethereum(ethereum_settlement_params) => {
#[cfg(not(feature = "testing"))]
{
Ok(Box::new(EthereumSettlementClient::new_with_settings(ethereum_settlement_params)))
Ok(Box::new(EthereumSettlementClient::new_with_params(ethereum_settlement_params)))
}
#[cfg(feature = "testing")]
{
Ok(Box::new(EthereumSettlementClient::with_test_settings(
Ok(Box::new(EthereumSettlementClient::with_test_params(
RootProvider::new_http(ethereum_settlement_params.ethereum_rpc_url.clone()),
Address::from_str(&ethereum_settlement_params.l1_core_contract_address)?,
ethereum_settlement_params.ethereum_rpc_url.clone(),
Expand All @@ -295,42 +295,17 @@ pub async fn build_settlement_client(
}
}
SettlementParams::Starknet(starknet_settlement_params) => {
Ok(Box::new(StarknetSettlementClient::new_with_settings(starknet_settlement_params).await))
Ok(Box::new(StarknetSettlementClient::new_with_params(starknet_settlement_params).await))
}
}

// match settlement_params {
// "ethereum" => {
// #[cfg(not(feature = "testing"))]
// {
// Ok(Box::new(EthereumSettlementClient::new_with_settings(settings_provider)))
// }
// #[cfg(feature = "testing")]
// {
// Ok(Box::new(EthereumSettlementClient::with_test_settings(
//
// RootProvider::new_http(get_env_var_or_panic("MADARA_ORCHESTRATOR_ETHEREUM_SETTLEMENT_RPC_URL"
// ).as_str().parse()?),
// Address::from_str(&get_env_var_or_panic("MADARA_ORCHESTRATOR_L1_CORE_CONTRACT_ADDRESS"))?,
//
// Url::from_str(get_env_var_or_panic("MADARA_ORCHESTRATOR_ETHEREUM_SETTLEMENT_RPC_URL").
// as_str())?,
//
// Some(Address::from_str(get_env_var_or_panic("MADARA_ORCHESTRATOR_STARKNET_OPERATOR_ADDRESS").
// as_str())?), )))
// }
// }
// "starknet" =>
// Ok(Box::new(StarknetSettlementClient::new_with_settings(settings_provider).await)), _
// => panic!("Unsupported Settlement layer"), }
}

pub async fn build_storage_client(
data_storage_params: &StorageParams,
provider_config: Arc<ProviderConfig>,
) -> Box<dyn DataStorage + Send + Sync> {
match data_storage_params {
StorageParams::AWSS3(aws_s3_params) => Box::new(AWSS3::new_with_settings(aws_s3_params, provider_config).await),
StorageParams::AWSS3(aws_s3_params) => Box::new(AWSS3::new_with_params(aws_s3_params, provider_config).await),
}
}

Expand All @@ -341,7 +316,7 @@ pub async fn build_alert_client(
match alert_params {
AlertParams::AWSSNS(aws_sns_params) => {
println!("Building alert client {}", aws_sns_params.sns_arn);
Box::new(AWSSNS::new_with_settings(aws_sns_params, provider_config).await)
Box::new(AWSSNS::new_with_params(aws_sns_params, provider_config).await)
}
}
}
Expand All @@ -354,6 +329,6 @@ pub fn build_queue_client(queue_params: &QueueParams) -> Box<dyn QueueProvider +

pub async fn build_database_client(database_params: &DatabaseParams) -> Box<dyn Database + Send + Sync> {
match database_params {
DatabaseParams::MongoDB(mongodb_params) => Box::new(MongoDb::new_with_settings(mongodb_params).await),
DatabaseParams::MongoDB(mongodb_params) => Box::new(MongoDb::new_with_params(mongodb_params).await),
}
}
2 changes: 1 addition & 1 deletion crates/orchestrator/src/data_storage/aws_s3/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct AWSS3 {
/// - initializing a new AWS S3 client
impl AWSS3 {
/// To init the struct with main settings
pub async fn new_with_settings(s3_config: &AWSS3Params, provider_config: Arc<ProviderConfig>) -> Self {
pub async fn new_with_params(s3_config: &AWSS3Params, provider_config: Arc<ProviderConfig>) -> Self {
let aws_config = provider_config.get_aws_client_or_panic();
// Building AWS S3 config
let mut s3_config_builder = aws_sdk_s3::config::Builder::from(aws_config);
Expand Down
2 changes: 1 addition & 1 deletion crates/orchestrator/src/database/mongodb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct MongoDb {
}

impl MongoDb {
pub async fn new_with_settings(mongodb_params: &MongoDBParams) -> Self {
pub async fn new_with_params(mongodb_params: &MongoDBParams) -> Self {
let mut client_options =
ClientOptions::parse(mongodb_params.connection_url.clone()).await.expect("Failed to parse MongoDB Url");
// Set the server_api field of the client_options object to set the version of the Stable API on the
Expand Down
2 changes: 0 additions & 2 deletions crates/orchestrator/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ async fn main() {
// TODO: could this be an ARC ?
let run_cmd: RunCmd = RunCmd::parse();

println!("{:?}", run_cmd.aws_sqs_args.queue_base_url);

// Analytics Setup
let instrumentation_params = run_cmd.validate_instrumentation_params().expect("Invalid instrumentation params");
let meter_provider = setup_analytics(&instrumentation_params);
Expand Down
Loading

0 comments on commit 98a5bc0

Please sign in to comment.