Skip to content

Commit

Permalink
Split Context to LocalContext and GlobalContext
Browse files Browse the repository at this point in the history
  • Loading branch information
pkolaczk committed Aug 17, 2024
1 parent 188b8dc commit b822096
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 62 deletions.
37 changes: 19 additions & 18 deletions src/exec/workload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::time::Instant;

use crate::error::LatteError;
use crate::scripting::cass_error::{CassError, CassErrorKind};
use crate::scripting::context::Context;
use crate::scripting::context::{GlobalContext, LocalContext};
use crate::stats::latency::LatencyDistributionRecorder;
use crate::stats::session::SessionStats;
use rand::distributions::{Distribution, WeightedIndex};
Expand All @@ -24,15 +24,15 @@ use rune::{vm_try, Any, Diagnostics, Source, Sources, ToValue, Unit, Value, Vm};
use serde::{Deserialize, Serialize};
use try_lock::TryLock;

/// Wraps a reference to Session that can be converted to a Rune `Value`
/// Wraps a reference to `Context` that can be converted to a Rune `Value`
/// and passed as one of `Args` arguments to a function.
struct SessionRef<'a> {
context: &'a Context,
struct ContextRef<'a> {
context: &'a GlobalContext,
}

impl SessionRef<'_> {
pub fn new(context: &Context) -> SessionRef {
SessionRef { context }
impl ContextRef<'_> {
pub fn new(context: &GlobalContext) -> ContextRef {
ContextRef { context }
}
}

Expand All @@ -44,7 +44,7 @@ impl SessionRef<'_> {
/// possible that the underlying `Session` gets dropped before the `Value` produced by this trait
/// implementation and the compiler is not going to catch that.
/// The receiver of a `Value` must ensure that it is dropped before `Session`!
impl<'a> ToValue for SessionRef<'a> {
impl<'a> ToValue for ContextRef<'a> {
fn to_value(self) -> VmResult<Value> {
let obj = unsafe { AnyObj::from_ref(self.context) };
VmResult::Ok(Value::from(vm_try!(Shared::new(obj))))
Expand All @@ -54,11 +54,11 @@ impl<'a> ToValue for SessionRef<'a> {
/// Wraps a mutable reference to Session that can be converted to a Rune `Value` and passed
/// as one of `Args` arguments to a function.
struct ContextRefMut<'a> {
context: &'a mut Context,
context: &'a mut GlobalContext,
}

impl ContextRefMut<'_> {
pub fn new(context: &mut Context) -> ContextRefMut {
pub fn new(context: &mut GlobalContext) -> ContextRefMut {
ContextRefMut { context }
}
}
Expand Down Expand Up @@ -264,23 +264,23 @@ impl Program {
/// Calls the script's `init` function.
/// Called once at the beginning of the benchmark.
/// Typically used to prepare statements.
pub async fn prepare(&mut self, context: &mut Context) -> Result<(), LatteError> {
pub async fn prepare(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> {
let context = ContextRefMut::new(context);
self.async_call(&FnRef::new(PREPARE_FN), (context,)).await?;
Ok(())
}

/// Calls the script's `schema` function.
/// Typically used to create database schema.
pub async fn schema(&mut self, context: &mut Context) -> Result<(), LatteError> {
pub async fn schema(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> {
let context = ContextRefMut::new(context);
self.async_call(&FnRef::new(SCHEMA_FN), (context,)).await?;
Ok(())
}

/// Calls the script's `erase` function.
/// Typically used to remove the data from the database before running the benchmark.
pub async fn erase(&mut self, context: &mut Context) -> Result<(), LatteError> {
pub async fn erase(&mut self, context: &mut GlobalContext) -> Result<(), LatteError> {
let context = ContextRefMut::new(context);
self.async_call(&FnRef::new(ERASE_FN), (context,)).await?;
Ok(())
Expand Down Expand Up @@ -412,14 +412,14 @@ impl FnStatsCollector {
}

pub struct Workload {
context: Context,
context: GlobalContext,
program: Program,
router: FunctionRouter,
state: TryLock<FnStatsCollector>,
}

impl Workload {
pub fn new(context: Context, program: Program, functions: &[(FnRef, f64)]) -> Workload {
pub fn new(context: GlobalContext, program: Program, functions: &[(FnRef, f64)]) -> Workload {
let state = FnStatsCollector::new(functions.iter().map(|x| x.0.clone()));
Workload {
context,
Expand Down Expand Up @@ -448,9 +448,10 @@ impl Workload {
pub async fn run(&self, cycle: i64) -> Result<(i64, Instant), LatteError> {
let start_time = Instant::now();
let mut rng = SmallRng::seed_from_u64(cycle as u64);
let context = SessionRef::new(&self.context);
let global_ctx = ContextRef::new(&self.context);
let local_ctx = LocalContext::new(cycle, global_ctx.to_value().into_result().unwrap());
let function = self.router.select(&mut rng);
let result = self.program.async_call(function, (context, cycle)).await;
let result = self.program.async_call(function, (local_ctx, cycle)).await;
let end_time = Instant::now();
let mut state = self.state.try_lock().unwrap();
let duration = end_time - start_time;
Expand All @@ -474,7 +475,7 @@ impl Workload {

/// Returns the reference to the contained context.
/// Allows to e.g. access context stats.
pub fn context(&self) -> &Context {
pub fn context(&self) -> &GlobalContext {
&self.context
}

Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::error::{LatteError, Result};
use crate::exec::{par_execute, ExecutionOptions};
use crate::report::{PathAndSummary, Report, RunConfigCmp};
use crate::scripting::connect::ClusterInfo;
use crate::scripting::context::Context;
use crate::scripting::context::GlobalContext;
use crate::stats::{BenchmarkCmp, BenchmarkStats, Recorder};
use exec::cycle::BoundedCycleCounter;
use exec::progress::Progress;
Expand Down Expand Up @@ -109,7 +109,7 @@ fn find_workload(workload: &Path) -> PathBuf {
}

/// Connects to the server and returns the session
async fn connect(conf: &ConnectionConf) -> Result<(Context, Option<ClusterInfo>)> {
async fn connect(conf: &ConnectionConf) -> Result<(GlobalContext, Option<ClusterInfo>)> {
eprintln!("info: Connecting to {:?}... ", conf.addresses);
let session = scripting::connect::connect(conf).await?;
let cluster_info = session.cluster_info().await?;
Expand Down
6 changes: 3 additions & 3 deletions src/scripting/connect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::config::ConnectionConf;
use crate::scripting::cass_error::{CassError, CassErrorKind};
use crate::scripting::context::Context;
use crate::scripting::context::GlobalContext;
use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod};
use scylla::load_balancing::DefaultPolicy;
use scylla::transport::session::PoolSize;
Expand All @@ -25,7 +25,7 @@ fn ssl_context(conf: &&ConnectionConf) -> Result<Option<SslContext>, CassError>
}

/// Configures connection to Cassandra.
pub async fn connect(conf: &ConnectionConf) -> Result<Context, CassError> {
pub async fn connect(conf: &ConnectionConf) -> Result<GlobalContext, CassError> {
let mut policy_builder = DefaultPolicy::builder().token_aware(true);
if let Some(dc) = &conf.datacenter {
policy_builder = policy_builder
Expand All @@ -47,7 +47,7 @@ pub async fn connect(conf: &ConnectionConf) -> Result<Context, CassError> {
.build()
.await
.map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?;
Ok(Context::new(scylla_session, conf.retry_strategy))
Ok(GlobalContext::new(scylla_session, conf.retry_strategy))
}

pub struct ClusterInfo {
Expand Down
58 changes: 45 additions & 13 deletions src/scripting/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use crate::error::LatteError;
use crate::scripting::bind;
use crate::scripting::cass_error::{CassError, CassErrorKind};
use crate::scripting::connect::ClusterInfo;
use crate::scripting::rng::Rng;
use crate::stats::session::SessionStats;
use rand::prelude::ThreadRng;
use rand::random;
use rune::runtime::{Object, Shared};
use rune::runtime::{AnyObj, BorrowRef, Object, Shared};
use rune::{Any, Value};
use scylla::prepared_statement::PreparedStatement;
use scylla::transport::errors::{DbError, QueryError};
Expand All @@ -22,7 +22,23 @@ use try_lock::TryLock;
/// This is the main object that a workload script uses to interface with the outside world.
/// It also tracks query execution metrics such as number of requests, rows, response times etc.
#[derive(Any)]
pub struct Context {
pub struct LocalContext {
#[rune(get)]
pub cycle: i64,
#[rune(get)]
pub rng: Value,
#[rune(get)]
pub global: Value,
}

impl From<LocalContext> for Value {
fn from(value: LocalContext) -> Self {
Value::Any(Shared::new(AnyObj::new(value).unwrap()).unwrap())
}
}

#[derive(Any)]
pub struct GlobalContext {
start_time: TryLock<Instant>,
session: Arc<scylla::Session>,
statements: HashMap<String, Arc<PreparedStatement>>,
Expand All @@ -32,7 +48,6 @@ pub struct Context {
pub load_cycle_count: u64,
#[rune(get)]
pub data: Value,
pub rng: ThreadRng,
}

// Needed, because Rune `Value` is !Send, as it may contain some internal pointers.
Expand All @@ -41,20 +56,21 @@ pub struct Context {
// To make it safe, the same `Context` is never used by more than one thread at once, and
// we make sure in `clone` to make a deep copy of the `data` field by serializing
// and deserializing it, so no pointers could get through.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}

impl Context {
pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Context {
Context {
unsafe impl Send for LocalContext {}
unsafe impl Sync for LocalContext {}
unsafe impl Send for GlobalContext {}
unsafe impl Sync for GlobalContext {}

impl GlobalContext {
pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Self {
Self {
start_time: TryLock::new(Instant::now()),
session: Arc::new(session),
statements: HashMap::new(),
stats: TryLock::new(SessionStats::new()),
retry_strategy,
load_cycle_count: 0,
data: Value::Object(Shared::new(Object::new()).unwrap()),
rng: rand::thread_rng(),
}
}

Expand All @@ -65,13 +81,12 @@ impl Context {
pub fn clone(&self) -> Result<Self, LatteError> {
let serialized = rmp_serde::to_vec(&self.data)?;
let deserialized: Value = rmp_serde::from_slice(&serialized)?;
Ok(Context {
Ok(Self {
session: self.session.clone(),
statements: self.statements.clone(),
stats: TryLock::new(SessionStats::default()),
data: deserialized,
start_time: TryLock::new(*self.start_time.try_lock().unwrap()),
rng: rand::thread_rng(),
..*self
})
}
Expand Down Expand Up @@ -188,6 +203,23 @@ impl Context {
}
}

impl LocalContext {
pub fn new(cycle: i64, global: Value) -> Self {
Self {
cycle,
global,
rng: Value::Any(Shared::new(AnyObj::new(Rng::with_seed(cycle)).unwrap()).unwrap()),
}
}

pub fn global(&self) -> BorrowRef<GlobalContext> {
let Value::Any(obj) = &self.global else {
panic!("global must be an object")
};
obj.downcast_borrow_ref().unwrap()
}
}

pub fn get_exponential_retry_interval(
min_interval: Duration,
max_interval: Duration,
Expand Down
72 changes: 52 additions & 20 deletions src/scripting/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::scripting::cass_error::CassError;
use crate::scripting::context::Context;
use crate::scripting::context::LocalContext;
use crate::scripting::cql_types::{Int8, Uuid};
use crate::scripting::Resources;
use chrono::Utc;
Expand All @@ -9,7 +9,7 @@ use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rune::macros::{quote, MacroContext, TokenStream};
use rune::parse::Parser;
use rune::runtime::{Function, Mut, Ref, VmError, VmResult};
use rune::runtime::{Function, Ref, VmError, VmResult};
use rune::{ast, vm_try, Value};
use statrs::distribution::{Normal, Uniform};
use std::collections::HashMap;
Expand Down Expand Up @@ -236,26 +236,58 @@ pub fn read_resource_words(path: &str) -> io::Result<Vec<String>> {
.collect())
}

#[rune::function(instance)]
pub async fn prepare(mut ctx: Mut<Context>, key: Ref<str>, cql: Ref<str>) -> Result<(), CassError> {
ctx.prepare(&key, &cql).await
}
pub mod local {
use super::*;

#[rune::function(instance)]
pub async fn execute(ctx: Ref<Context>, cql: Ref<str>) -> Result<(), CassError> {
ctx.execute(cql.deref()).await
}
#[rune::function(instance)]
pub async fn execute(ctx: Ref<LocalContext>, cql: Ref<str>) -> Result<(), CassError> {
ctx.global().execute(cql.deref()).await
}

#[rune::function(instance)]
pub async fn execute_prepared(
ctx: Ref<Context>,
key: Ref<str>,
params: Value,
) -> Result<(), CassError> {
ctx.execute_prepared(&key, params).await
#[rune::function(instance)]
pub async fn execute_prepared(
ctx: Ref<LocalContext>,
key: Ref<str>,
params: Value,
) -> Result<(), CassError> {
ctx.global().execute_prepared(&key, params).await
}

#[rune::function(instance)]
pub fn elapsed_secs(ctx: &LocalContext) -> f64 {
ctx.global().elapsed_secs()
}
}

#[rune::function(instance)]
pub fn elapsed_secs(ctx: &Context) -> f64 {
ctx.elapsed_secs()
pub mod global {
use super::*;
use crate::scripting::context::GlobalContext;
use rune::runtime::Mut;
#[rune::function(instance)]
pub async fn prepare(
mut ctx: Mut<GlobalContext>,
key: Ref<str>,
cql: Ref<str>,
) -> Result<(), CassError> {
ctx.prepare(&key, &cql).await
}

#[rune::function(instance)]
pub async fn execute(ctx: Ref<GlobalContext>, cql: Ref<str>) -> Result<(), CassError> {
ctx.execute(cql.deref()).await
}

#[rune::function(instance)]
pub async fn execute_prepared(
ctx: Ref<GlobalContext>,
key: Ref<str>,
params: Value,
) -> Result<(), CassError> {
ctx.execute_prepared(&key, params).await
}

#[rune::function(instance)]
pub fn elapsed_secs(ctx: &GlobalContext) -> f64 {
ctx.elapsed_secs()
}
}
Loading

0 comments on commit b822096

Please sign in to comment.