From b822096af2eb339d58902d9bad9806ea3c34e048 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Ko=C5=82aczkowski?= Date: Sat, 17 Aug 2024 17:50:31 +0200 Subject: [PATCH] Split Context to LocalContext and GlobalContext --- src/exec/workload.rs | 37 ++++++++++---------- src/main.rs | 4 +-- src/scripting/connect.rs | 6 ++-- src/scripting/context.rs | 58 +++++++++++++++++++++++------- src/scripting/functions.rs | 72 +++++++++++++++++++++++++++----------- src/scripting/mod.rs | 25 +++++++++---- src/scripting/rng.rs | 51 +++++++++++++++++++++++++++ 7 files changed, 191 insertions(+), 62 deletions(-) create mode 100644 src/scripting/rng.rs diff --git a/src/exec/workload.rs b/src/exec/workload.rs index ade7692..1d378da 100644 --- a/src/exec/workload.rs +++ b/src/exec/workload.rs @@ -9,7 +9,7 @@ use std::time::Instant; use crate::error::LatteError; use crate::scripting::cass_error::{CassError, CassErrorKind}; -use crate::scripting::context::Context; +use crate::scripting::context::{GlobalContext, LocalContext}; use crate::stats::latency::LatencyDistributionRecorder; use crate::stats::session::SessionStats; use rand::distributions::{Distribution, WeightedIndex}; @@ -24,15 +24,15 @@ use rune::{vm_try, Any, Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; use serde::{Deserialize, Serialize}; use try_lock::TryLock; -/// Wraps a reference to Session that can be converted to a Rune `Value` +/// Wraps a reference to `Context` that can be converted to a Rune `Value` /// and passed as one of `Args` arguments to a function. -struct SessionRef<'a> { - context: &'a Context, +struct ContextRef<'a> { + context: &'a GlobalContext, } -impl SessionRef<'_> { - pub fn new(context: &Context) -> SessionRef { - SessionRef { context } +impl ContextRef<'_> { + pub fn new(context: &GlobalContext) -> ContextRef { + ContextRef { context } } } @@ -44,7 +44,7 @@ impl SessionRef<'_> { /// possible that the underlying `Session` gets dropped before the `Value` produced by this trait /// implementation and the compiler is not going to catch that. /// The receiver of a `Value` must ensure that it is dropped before `Session`! -impl<'a> ToValue for SessionRef<'a> { +impl<'a> ToValue for ContextRef<'a> { fn to_value(self) -> VmResult { let obj = unsafe { AnyObj::from_ref(self.context) }; VmResult::Ok(Value::from(vm_try!(Shared::new(obj)))) @@ -54,11 +54,11 @@ impl<'a> ToValue for SessionRef<'a> { /// Wraps a mutable reference to Session that can be converted to a Rune `Value` and passed /// as one of `Args` arguments to a function. struct ContextRefMut<'a> { - context: &'a mut Context, + context: &'a mut GlobalContext, } impl ContextRefMut<'_> { - pub fn new(context: &mut Context) -> ContextRefMut { + pub fn new(context: &mut GlobalContext) -> ContextRefMut { ContextRefMut { context } } } @@ -264,7 +264,7 @@ impl Program { /// Calls the script's `init` function. /// Called once at the beginning of the benchmark. /// Typically used to prepare statements. - pub async fn prepare(&mut self, context: &mut Context) -> Result<(), LatteError> { + pub async fn prepare(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> { let context = ContextRefMut::new(context); self.async_call(&FnRef::new(PREPARE_FN), (context,)).await?; Ok(()) @@ -272,7 +272,7 @@ impl Program { /// Calls the script's `schema` function. /// Typically used to create database schema. - pub async fn schema(&mut self, context: &mut Context) -> Result<(), LatteError> { + pub async fn schema(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> { let context = ContextRefMut::new(context); self.async_call(&FnRef::new(SCHEMA_FN), (context,)).await?; Ok(()) @@ -280,7 +280,7 @@ impl Program { /// Calls the script's `erase` function. /// Typically used to remove the data from the database before running the benchmark. - pub async fn erase(&mut self, context: &mut Context) -> Result<(), LatteError> { + pub async fn erase(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> { let context = ContextRefMut::new(context); self.async_call(&FnRef::new(ERASE_FN), (context,)).await?; Ok(()) @@ -412,14 +412,14 @@ impl FnStatsCollector { } pub struct Workload { - context: Context, + context: GlobalContext, program: Program, router: FunctionRouter, state: TryLock, } impl Workload { - pub fn new(context: Context, program: Program, functions: &[(FnRef, f64)]) -> Workload { + pub fn new(context: GlobalContext, program: Program, functions: &[(FnRef, f64)]) -> Workload { let state = FnStatsCollector::new(functions.iter().map(|x| x.0.clone())); Workload { context, @@ -448,9 +448,10 @@ impl Workload { pub async fn run(&self, cycle: i64) -> Result<(i64, Instant), LatteError> { let start_time = Instant::now(); let mut rng = SmallRng::seed_from_u64(cycle as u64); - let context = SessionRef::new(&self.context); + let global_ctx = ContextRef::new(&self.context); + let local_ctx = LocalContext::new(cycle, global_ctx.to_value().into_result().unwrap()); let function = self.router.select(&mut rng); - let result = self.program.async_call(function, (context, cycle)).await; + let result = self.program.async_call(function, (local_ctx, cycle)).await; let end_time = Instant::now(); let mut state = self.state.try_lock().unwrap(); let duration = end_time - start_time; @@ -474,7 +475,7 @@ impl Workload { /// Returns the reference to the contained context. /// Allows to e.g. access context stats. - pub fn context(&self) -> &Context { + pub fn context(&self) -> &GlobalContext { &self.context } diff --git a/src/main.rs b/src/main.rs index 7c8a1c1..111931e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,7 +31,7 @@ use crate::error::{LatteError, Result}; use crate::exec::{par_execute, ExecutionOptions}; use crate::report::{PathAndSummary, Report, RunConfigCmp}; use crate::scripting::connect::ClusterInfo; -use crate::scripting::context::Context; +use crate::scripting::context::GlobalContext; use crate::stats::{BenchmarkCmp, BenchmarkStats, Recorder}; use exec::cycle::BoundedCycleCounter; use exec::progress::Progress; @@ -109,7 +109,7 @@ fn find_workload(workload: &Path) -> PathBuf { } /// Connects to the server and returns the session -async fn connect(conf: &ConnectionConf) -> Result<(Context, Option)> { +async fn connect(conf: &ConnectionConf) -> Result<(GlobalContext, Option)> { eprintln!("info: Connecting to {:?}... ", conf.addresses); let session = scripting::connect::connect(conf).await?; let cluster_info = session.cluster_info().await?; diff --git a/src/scripting/connect.rs b/src/scripting/connect.rs index 6e22df6..2d945f9 100644 --- a/src/scripting/connect.rs +++ b/src/scripting/connect.rs @@ -1,6 +1,6 @@ use crate::config::ConnectionConf; use crate::scripting::cass_error::{CassError, CassErrorKind}; -use crate::scripting::context::Context; +use crate::scripting::context::GlobalContext; use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod}; use scylla::load_balancing::DefaultPolicy; use scylla::transport::session::PoolSize; @@ -25,7 +25,7 @@ fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> } /// Configures connection to Cassandra. -pub async fn connect(conf: &ConnectionConf) -> Result { +pub async fn connect(conf: &ConnectionConf) -> Result { let mut policy_builder = DefaultPolicy::builder().token_aware(true); if let Some(dc) = &conf.datacenter { policy_builder = policy_builder @@ -47,7 +47,7 @@ pub async fn connect(conf: &ConnectionConf) -> Result { .build() .await .map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?; - Ok(Context::new(scylla_session, conf.retry_strategy)) + Ok(GlobalContext::new(scylla_session, conf.retry_strategy)) } pub struct ClusterInfo { diff --git a/src/scripting/context.rs b/src/scripting/context.rs index b1e29cf..f07e53f 100644 --- a/src/scripting/context.rs +++ b/src/scripting/context.rs @@ -3,10 +3,10 @@ use crate::error::LatteError; use crate::scripting::bind; use crate::scripting::cass_error::{CassError, CassErrorKind}; use crate::scripting::connect::ClusterInfo; +use crate::scripting::rng::Rng; use crate::stats::session::SessionStats; -use rand::prelude::ThreadRng; use rand::random; -use rune::runtime::{Object, Shared}; +use rune::runtime::{AnyObj, BorrowRef, Object, Shared}; use rune::{Any, Value}; use scylla::prepared_statement::PreparedStatement; use scylla::transport::errors::{DbError, QueryError}; @@ -22,7 +22,23 @@ use try_lock::TryLock; /// This is the main object that a workload script uses to interface with the outside world. /// It also tracks query execution metrics such as number of requests, rows, response times etc. #[derive(Any)] -pub struct Context { +pub struct LocalContext { + #[rune(get)] + pub cycle: i64, + #[rune(get)] + pub rng: Value, + #[rune(get)] + pub global: Value, +} + +impl From for Value { + fn from(value: LocalContext) -> Self { + Value::Any(Shared::new(AnyObj::new(value).unwrap()).unwrap()) + } +} + +#[derive(Any)] +pub struct GlobalContext { start_time: TryLock, session: Arc, statements: HashMap>, @@ -32,7 +48,6 @@ pub struct Context { pub load_cycle_count: u64, #[rune(get)] pub data: Value, - pub rng: ThreadRng, } // Needed, because Rune `Value` is !Send, as it may contain some internal pointers. @@ -41,12 +56,14 @@ pub struct Context { // To make it safe, the same `Context` is never used by more than one thread at once, and // we make sure in `clone` to make a deep copy of the `data` field by serializing // and deserializing it, so no pointers could get through. -unsafe impl Send for Context {} -unsafe impl Sync for Context {} - -impl Context { - pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Context { - Context { +unsafe impl Send for LocalContext {} +unsafe impl Sync for LocalContext {} +unsafe impl Send for GlobalContext {} +unsafe impl Sync for GlobalContext {} + +impl GlobalContext { + pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Self { + Self { start_time: TryLock::new(Instant::now()), session: Arc::new(session), statements: HashMap::new(), @@ -54,7 +71,6 @@ impl Context { retry_strategy, load_cycle_count: 0, data: Value::Object(Shared::new(Object::new()).unwrap()), - rng: rand::thread_rng(), } } @@ -65,13 +81,12 @@ impl Context { pub fn clone(&self) -> Result { let serialized = rmp_serde::to_vec(&self.data)?; let deserialized: Value = rmp_serde::from_slice(&serialized)?; - Ok(Context { + Ok(Self { session: self.session.clone(), statements: self.statements.clone(), stats: TryLock::new(SessionStats::default()), data: deserialized, start_time: TryLock::new(*self.start_time.try_lock().unwrap()), - rng: rand::thread_rng(), ..*self }) } @@ -188,6 +203,23 @@ impl Context { } } +impl LocalContext { + pub fn new(cycle: i64, global: Value) -> Self { + Self { + cycle, + global, + rng: Value::Any(Shared::new(AnyObj::new(Rng::with_seed(cycle)).unwrap()).unwrap()), + } + } + + pub fn global(&self) -> BorrowRef { + let Value::Any(obj) = &self.global else { + panic!("global must be an object") + }; + obj.downcast_borrow_ref().unwrap() + } +} + pub fn get_exponential_retry_interval( min_interval: Duration, max_interval: Duration, diff --git a/src/scripting/functions.rs b/src/scripting/functions.rs index 5eed86c..7fa3a0b 100644 --- a/src/scripting/functions.rs +++ b/src/scripting/functions.rs @@ -1,5 +1,5 @@ use crate::scripting::cass_error::CassError; -use crate::scripting::context::Context; +use crate::scripting::context::LocalContext; use crate::scripting::cql_types::{Int8, Uuid}; use crate::scripting::Resources; use chrono::Utc; @@ -9,7 +9,7 @@ use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use rune::macros::{quote, MacroContext, TokenStream}; use rune::parse::Parser; -use rune::runtime::{Function, Mut, Ref, VmError, VmResult}; +use rune::runtime::{Function, Ref, VmError, VmResult}; use rune::{ast, vm_try, Value}; use statrs::distribution::{Normal, Uniform}; use std::collections::HashMap; @@ -236,26 +236,58 @@ pub fn read_resource_words(path: &str) -> io::Result> { .collect()) } -#[rune::function(instance)] -pub async fn prepare(mut ctx: Mut, key: Ref, cql: Ref) -> Result<(), CassError> { - ctx.prepare(&key, &cql).await -} +pub mod local { + use super::*; -#[rune::function(instance)] -pub async fn execute(ctx: Ref, cql: Ref) -> Result<(), CassError> { - ctx.execute(cql.deref()).await -} + #[rune::function(instance)] + pub async fn execute(ctx: Ref, cql: Ref) -> Result<(), CassError> { + ctx.global().execute(cql.deref()).await + } -#[rune::function(instance)] -pub async fn execute_prepared( - ctx: Ref, - key: Ref, - params: Value, -) -> Result<(), CassError> { - ctx.execute_prepared(&key, params).await + #[rune::function(instance)] + pub async fn execute_prepared( + ctx: Ref, + key: Ref, + params: Value, + ) -> Result<(), CassError> { + ctx.global().execute_prepared(&key, params).await + } + + #[rune::function(instance)] + pub fn elapsed_secs(ctx: &LocalContext) -> f64 { + ctx.global().elapsed_secs() + } } -#[rune::function(instance)] -pub fn elapsed_secs(ctx: &Context) -> f64 { - ctx.elapsed_secs() +pub mod global { + use super::*; + use crate::scripting::context::GlobalContext; + use rune::runtime::Mut; + #[rune::function(instance)] + pub async fn prepare( + mut ctx: Mut, + key: Ref, + cql: Ref, + ) -> Result<(), CassError> { + ctx.prepare(&key, &cql).await + } + + #[rune::function(instance)] + pub async fn execute(ctx: Ref, cql: Ref) -> Result<(), CassError> { + ctx.execute(cql.deref()).await + } + + #[rune::function(instance)] + pub async fn execute_prepared( + ctx: Ref, + key: Ref, + params: Value, + ) -> Result<(), CassError> { + ctx.execute_prepared(&key, params).await + } + + #[rune::function(instance)] + pub fn elapsed_secs(ctx: &GlobalContext) -> f64 { + ctx.elapsed_secs() + } } diff --git a/src/scripting/mod.rs b/src/scripting/mod.rs index bc3a09a..7f1d670 100644 --- a/src/scripting/mod.rs +++ b/src/scripting/mod.rs @@ -1,5 +1,6 @@ use crate::scripting::cass_error::CassError; -use crate::scripting::context::Context; +use crate::scripting::context::{GlobalContext, LocalContext}; +use crate::scripting::rng::Rng; use rune::{ContextError, Module}; use rust_embed::RustEmbed; use std::collections::HashMap; @@ -10,6 +11,7 @@ pub mod connect; pub mod context; mod cql_types; mod functions; +mod rng; #[derive(RustEmbed)] #[folder = "resources/"] @@ -24,11 +26,22 @@ fn try_install( params: HashMap, ) -> Result<(), ContextError> { let mut context_module = Module::default(); - context_module.ty::()?; - context_module.function_meta(functions::execute)?; - context_module.function_meta(functions::prepare)?; - context_module.function_meta(functions::execute_prepared)?; - context_module.function_meta(functions::elapsed_secs)?; + context_module.ty::()?; + context_module.ty::()?; + context_module.ty::()?; + + context_module.function_meta(functions::local::execute)?; + context_module.function_meta(functions::local::execute_prepared)?; + context_module.function_meta(functions::local::elapsed_secs)?; + + context_module.function_meta(functions::global::execute)?; + context_module.function_meta(functions::global::prepare)?; + context_module.function_meta(functions::global::execute_prepared)?; + context_module.function_meta(functions::global::elapsed_secs)?; + + context_module.function_meta(Rng::gen_range)?; + context_module.function_meta(Rng::gen_i64)?; + context_module.function_meta(Rng::gen_f64)?; let mut err_module = Module::default(); err_module.ty::()?; diff --git a/src/scripting/rng.rs b/src/scripting/rng.rs new file mode 100644 index 0000000..5059ff5 --- /dev/null +++ b/src/scripting/rng.rs @@ -0,0 +1,51 @@ +use rand::rngs::SmallRng; +use rand::{Rng as RRng, SeedableRng}; +use rune::runtime::{VmError, VmResult}; +use rune::Value; +use std::any::Any; + +#[derive(Debug, Clone, rune::Any)] +pub struct Rng(SmallRng); + +impl Rng { + pub fn with_seed(seed: i64) -> Self { + Self::with_rng(SmallRng::seed_from_u64(seed as u64)) + } + + pub fn with_rng(rng: SmallRng) -> Self { + Self(rng) + } + + #[rune::function] + pub fn gen_i64(&mut self) -> Value { + Value::Integer(self.0.gen()) + } + + #[rune::function] + pub fn gen_f64(&mut self) -> Value { + Value::Float(self.0.gen()) + } + + #[rune::function] + pub fn gen_range(&mut self, min: Value, max: Value) -> VmResult { + match (min, max) { + (Value::Integer(min), Value::Integer(max)) => { + VmResult::Ok(Value::Integer(self.0.gen_range(min..max))) + } + (Value::Float(min), Value::Float(max)) => { + VmResult::Ok(Value::Float(self.0.gen_range(min..max))) + } + (Value::Char(min), Value::Char(max)) => { + VmResult::Ok(Value::Char(self.0.gen_range(min..max))) + } + (Value::Byte(min), Value::Byte(max)) => { + VmResult::Ok(Value::Byte(self.0.gen_range(min..max))) + } + (min, max) => VmResult::Err(VmError::panic(format!( + "Invalid argument types: {:?}, {:?}", + min.type_id(), + max.type_id() + ))), + } + } +}