Skip to content

Implement exposing/enforcing coordinator for request #299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:SerialConsistencyTests.*\
:HeartbeatTests.*\
:PreparedTests.*\
:StatementNoClusterTests.*\
:StatementTests.*\
:NamedParametersTests.*\
:CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\
:ControlConnectionTests.*\
Expand All @@ -27,6 +29,7 @@ SCYLLA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:PreparedMetadataTests.*\
:UseKeyspaceCaseSensitiveTests.*\
:ServerSideFailureTests.*\
:ServerSideFailureThreeNodeTests.*\
:TimestampTests.*\
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
:MetricsTests.Integration_Cassandra_Requests\
Expand Down Expand Up @@ -69,6 +72,8 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:SerialConsistencyTests.*\
:HeartbeatTests.*\
:PreparedTests.*\
:StatementNoClusterTests.*\
:StatementTests.*\
:NamedParametersTests.*\
:CassandraTypes/CassandraTypesTests/*.Integration_Cassandra_*\
:ControlConnectionTests.*\
Expand All @@ -83,6 +88,7 @@ CASSANDRA_TEST_FILTER := $(subst ${SPACE},${EMPTY},ClusterTests.*\
:PreparedMetadataTests.*\
:UseKeyspaceCaseSensitiveTests.*\
:ServerSideFailureTests.*\
:ServerSideFailureThreeNodeTests.*\
:TimestampTests.*\
:MetricsTests.Integration_Cassandra_ErrorsRequestTimeouts\
:MetricsTests.Integration_Cassandra_Requests\
Expand Down
9 changes: 4 additions & 5 deletions scylla-rust-wrapper/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions scylla-rust-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ categories = ["database"]
license = "MIT OR Apache-2.0"

[dependencies]
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0", features = [
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "32d179cb2", features = [
"openssl-010",
"metrics",
] }
Expand All @@ -34,7 +34,7 @@ bindgen = "0.65"
chrono = "0.4.20"

[dev-dependencies]
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v1.1.0" }
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "32d179cb2" }
bytes = "1.10.0"

assert_matches = "1.5.0"
Expand Down
69 changes: 46 additions & 23 deletions scylla-rust-wrapper/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ use crate::cass_error::CassErrorMessage;
use crate::cass_error::ToCassError;
use crate::execution_error::CassErrorResult;
use crate::prepared::CassPrepared;
use crate::query_result::CassResult;
use crate::query_result::{CassNode, CassResult};
use crate::types::*;
use crate::uuid::CassUuid;
use futures::future;
use std::future::Future;
use std::mem;
use std::os::raw::c_void;
use std::sync::{Arc, Condvar, Mutex};
use std::sync::{Arc, Condvar, Mutex, OnceLock};
use tokio::task::JoinHandle;
use tokio::time::Duration;

#[derive(Debug)]
pub enum CassResultValue {
Empty,
QueryResult(Arc<CassResult>),
Expand Down Expand Up @@ -50,14 +51,14 @@ impl BoundCallback {

#[derive(Default)]
struct CassFutureState {
value: Option<CassFutureResult>,
err_string: Option<String>,
callback: Option<BoundCallback>,
join_handle: Option<JoinHandle<()>>,
}

pub struct CassFuture {
state: Mutex<CassFutureState>,
result: OnceLock<CassFutureResult>,
wait_for_value: Condvar,
}

Expand Down Expand Up @@ -87,14 +88,18 @@ impl CassFuture {
) -> Arc<CassFuture> {
let cass_fut = Arc::new(CassFuture {
state: Mutex::new(Default::default()),
result: OnceLock::new(),
wait_for_value: Condvar::new(),
});
let cass_fut_clone = Arc::clone(&cass_fut);
let join_handle = RUNTIME.spawn(async move {
let r = fut.await;
let maybe_cb = {
let mut guard = cass_fut_clone.state.lock().unwrap();
guard.value = Some(r);
cass_fut_clone
.result
.set(r)
.expect("Tried to resolve future result twice!");
// Take the callback and call it after releasing the lock
guard.callback.take()
};
Expand All @@ -115,16 +120,17 @@ impl CassFuture {

pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
Arc::new(CassFuture {
state: Mutex::new(CassFutureState {
value: Some(r),
..Default::default()
}),
state: Mutex::new(CassFutureState::default()),
result: OnceLock::from(r),
wait_for_value: Condvar::new(),
})
}

pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
pub fn with_waited_result<'s, T>(&'s self, f: impl FnOnce(&'s CassFutureResult) -> T) -> T
where
T: 's,
{
self.with_waited_state(|_| f(self.result.get().unwrap()))
}

/// Awaits the future until completion.
Expand Down Expand Up @@ -153,7 +159,7 @@ impl CassFuture {
guard = self
.wait_for_value
.wait_while(guard, |state| {
state.value.is_none() && state.join_handle.is_none()
self.result.get().is_none() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
Expand All @@ -171,10 +177,10 @@ impl CassFuture {

fn with_waited_result_timed<T>(
&self,
f: impl FnOnce(&mut CassFutureResult) -> T,
f: impl FnOnce(&CassFutureResult) -> T,
timeout_duration: Duration,
) -> Result<T, FutureError> {
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
self.with_waited_state_timed(|_| f(self.result.get().unwrap()), timeout_duration)
}

/// Tries to await the future with a given timeout.
Expand Down Expand Up @@ -242,7 +248,7 @@ impl CassFuture {
let (guard_result, timeout_result) = self
.wait_for_value
.wait_timeout_while(guard, remaining_timeout, |state| {
state.value.is_none() && state.join_handle.is_none()
self.result.get().is_none() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
Expand Down Expand Up @@ -275,7 +281,7 @@ impl CassFuture {
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
}
let bound_cb = BoundCallback { cb, data };
if lock.value.is_some() {
if self.result.get().is_some() {
// The value is already available, we need to call the callback ourselves
mem::drop(lock);
bound_cb.invoke(self_ptr);
Expand Down Expand Up @@ -345,8 +351,7 @@ pub unsafe extern "C" fn cass_future_ready(
return cass_false;
};

let state_guard = future.state.lock().unwrap();
match state_guard.value {
match future.result.get() {
None => cass_false,
Some(_) => cass_true,
}
Expand All @@ -361,7 +366,7 @@ pub unsafe extern "C" fn cass_future_error_code(
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

future.with_waited_result(|r: &mut CassFutureResult| match r {
future.with_waited_result(|r: &CassFutureResult| match r {
Ok(CassResultValue::QueryError(err)) => err.to_cass_error(),
Err((err, _)) => *err,
_ => CassError::CASS_OK,
Expand All @@ -380,7 +385,7 @@ pub unsafe extern "C" fn cass_future_error_message(
};

future.with_waited_state(|state: &mut CassFutureState| {
let value = &state.value;
let value = future.result.get();
let msg = state
.err_string
.get_or_insert_with(|| match value.as_ref().unwrap() {
Expand All @@ -407,7 +412,7 @@ pub unsafe extern "C" fn cass_future_get_result(
};

future
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassResult>> {
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassResult>> {
match r.as_ref().ok()? {
CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)),
_ => None,
Expand All @@ -426,7 +431,7 @@ pub unsafe extern "C" fn cass_future_get_error_result(
};

future
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassErrorResult>> {
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassErrorResult>> {
match r.as_ref().ok()? {
CassResultValue::QueryError(qr) => Some(Arc::clone(qr)),
_ => None,
Expand All @@ -445,7 +450,7 @@ pub unsafe extern "C" fn cass_future_get_prepared(
};

future
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassPrepared>> {
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassPrepared>> {
match r.as_ref().ok()? {
CassResultValue::Prepared(p) => Some(Arc::clone(p)),
_ => None,
Expand All @@ -464,7 +469,7 @@ pub unsafe extern "C" fn cass_future_tracing_id(
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

future.with_waited_result(|r: &mut CassFutureResult| match r {
future.with_waited_result(|r: &CassFutureResult| match r {
Ok(CassResultValue::QueryResult(result)) => match result.tracing_id {
Some(id) => {
unsafe { *tracing_id = CassUuid::from(id) };
Expand All @@ -476,6 +481,24 @@ pub unsafe extern "C" fn cass_future_tracing_id(
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_future_coordinator(
future_raw: CassBorrowedSharedPtr<CassFuture, CMut>,
) -> CassBorrowedSharedPtr<CassNode, CConst> {
let Some(future) = ArcFFI::as_ref(future_raw) else {
tracing::error!("Provided null future to cass_future_coordinator!");
return RefFFI::null();
};

future.with_waited_result(|r| match r {
Ok(CassResultValue::QueryResult(result)) => {
// unwrap: Coordinator is `None` only for tests.
RefFFI::as_ptr(result.coordinator.as_ref().unwrap())
}
_ => RefFFI::null(),
})
}

#[cfg(test)]
mod tests {
use crate::testing::{assert_cass_error_eq, assert_cass_future_error_message_eq};
Expand Down
48 changes: 45 additions & 3 deletions scylla-rust-wrapper/src/integration_testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ use scylla::errors::{RequestAttemptError, RequestError};
use scylla::observability::history::{AttemptId, HistoryListener, RequestId, SpeculativeId};
use scylla::policies::retry::RetryDecision;

use crate::argconv::{BoxFFI, CMut, CassBorrowedExclusivePtr};
use crate::argconv::{
ArcFFI, BoxFFI, CConst, CMut, CassBorrowedExclusivePtr, CassBorrowedSharedPtr,
};
use crate::batch::CassBatch;
use crate::cluster::CassCluster;
use crate::future::{CassFuture, CassResultValue};
use crate::statement::{BoundStatement, CassStatement};
use crate::types::{cass_int32_t, cass_uint16_t, cass_uint64_t, size_t};

Expand Down Expand Up @@ -60,8 +63,47 @@ pub unsafe extern "C" fn testing_cluster_get_contact_points(
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn testing_free_contact_points(contact_points: *mut c_char) {
let _ = unsafe { CString::from_raw(contact_points) };
pub unsafe extern "C" fn testing_future_get_host(
future_raw: CassBorrowedSharedPtr<CassFuture, CConst>,
host: *mut *mut c_char,
host_length: *mut size_t,
) {
let Some(future) = ArcFFI::as_ref(future_raw) else {
tracing::error!("Provided null future pointer to testing_future_get_host!");
unsafe {
*host = std::ptr::null_mut();
*host_length = 0;
};
return;
};

future.with_waited_result(|r| match r {
Ok(CassResultValue::QueryResult(result)) => {
// unwrap: Coordinator is none only for unit tests.
let coordinator = result.coordinator.as_ref().unwrap();

let ip_addr_str = coordinator.node().address.ip().to_string();
let length = ip_addr_str.len();

let ip_addr_cstr = CString::new(ip_addr_str).expect(
"String obtained from IpAddr::to_string() should not contain any nul bytes!",
);

unsafe {
*host = ip_addr_cstr.into_raw();
*host_length = length as size_t
};
}
_ => unsafe {
*host = std::ptr::null_mut();
*host_length = 0;
},
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn testing_free_cstring(s: *mut c_char) {
let _ = unsafe { CString::from_raw(s) };
}

#[derive(Debug)]
Expand Down
Loading