Skip to content

Commit 3925fa0

Browse files
committed
future: store result in OnceLock
The result is going to be initialized only once, thus we do not need to store it behind a mutex. We can use OnceLock instead. Thanks to that, we can remove the unsafe logic which extends the lifetime of `coordinator` reference in cass_future_coordinator. We now guarantee that the result will be immutable once future is resolved - the guarantee is provided on the type-level.
1 parent 36f02ae commit 3925fa0

File tree

2 files changed

+61
-51
lines changed

2 files changed

+61
-51
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ use crate::query_result::{CassNode, CassResult};
99
use crate::types::*;
1010
use crate::uuid::CassUuid;
1111
use futures::future;
12-
use scylla::response::Coordinator;
1312
use std::future::Future;
1413
use std::mem;
1514
use std::os::raw::c_void;
16-
use std::sync::{Arc, Condvar, Mutex};
15+
use std::sync::{Arc, Condvar, Mutex, OnceLock};
1716
use tokio::task::JoinHandle;
1817
use tokio::time::Duration;
1918

19+
#[derive(Debug)]
2020
pub enum CassResultValue {
2121
Empty,
2222
QueryResult(Arc<CassResult>),
@@ -51,14 +51,14 @@ impl BoundCallback {
5151

5252
#[derive(Default)]
5353
struct CassFutureState {
54-
value: Option<CassFutureResult>,
5554
err_string: Option<String>,
5655
callback: Option<BoundCallback>,
5756
join_handle: Option<JoinHandle<()>>,
5857
}
5958

6059
pub struct CassFuture {
6160
state: Mutex<CassFutureState>,
61+
result: OnceLock<CassFutureResult>,
6262
wait_for_value: Condvar,
6363
}
6464

@@ -88,14 +88,18 @@ impl CassFuture {
8888
) -> Arc<CassFuture> {
8989
let cass_fut = Arc::new(CassFuture {
9090
state: Mutex::new(Default::default()),
91+
result: OnceLock::new(),
9192
wait_for_value: Condvar::new(),
9293
});
9394
let cass_fut_clone = Arc::clone(&cass_fut);
9495
let join_handle = RUNTIME.spawn(async move {
9596
let r = fut.await;
9697
let maybe_cb = {
9798
let mut guard = cass_fut_clone.state.lock().unwrap();
98-
guard.value = Some(r);
99+
cass_fut_clone
100+
.result
101+
.set(r)
102+
.expect("Tried to resolve future result twice!");
99103
// Take the callback and call it after releasing the lock
100104
guard.callback.take()
101105
};
@@ -116,16 +120,17 @@ impl CassFuture {
116120

117121
pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
118122
Arc::new(CassFuture {
119-
state: Mutex::new(CassFutureState {
120-
value: Some(r),
121-
..Default::default()
122-
}),
123+
state: Mutex::new(CassFutureState::default()),
124+
result: OnceLock::from(r),
123125
wait_for_value: Condvar::new(),
124126
})
125127
}
126128

127-
pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
128-
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
129+
pub fn with_waited_result<'s, T>(&'s self, f: impl FnOnce(&'s CassFutureResult) -> T) -> T
130+
where
131+
T: 's,
132+
{
133+
self.with_waited_state(|_| f(self.result.get().unwrap()))
129134
}
130135

131136
/// Awaits the future until completion.
@@ -154,7 +159,7 @@ impl CassFuture {
154159
guard = self
155160
.wait_for_value
156161
.wait_while(guard, |state| {
157-
state.value.is_none() && state.join_handle.is_none()
162+
self.result.get().is_none() && state.join_handle.is_none()
158163
})
159164
// unwrap: Error appears only when mutex is poisoned.
160165
.unwrap();
@@ -172,10 +177,10 @@ impl CassFuture {
172177

173178
fn with_waited_result_timed<T>(
174179
&self,
175-
f: impl FnOnce(&mut CassFutureResult) -> T,
180+
f: impl FnOnce(&CassFutureResult) -> T,
176181
timeout_duration: Duration,
177182
) -> Result<T, FutureError> {
178-
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
183+
self.with_waited_state_timed(|_| f(self.result.get().unwrap()), timeout_duration)
179184
}
180185

181186
/// Tries to await the future with a given timeout.
@@ -243,7 +248,7 @@ impl CassFuture {
243248
let (guard_result, timeout_result) = self
244249
.wait_for_value
245250
.wait_timeout_while(guard, remaining_timeout, |state| {
246-
state.value.is_none() && state.join_handle.is_none()
251+
self.result.get().is_none() && state.join_handle.is_none()
247252
})
248253
// unwrap: Error appears only when mutex is poisoned.
249254
.unwrap();
@@ -276,7 +281,7 @@ impl CassFuture {
276281
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
277282
}
278283
let bound_cb = BoundCallback { cb, data };
279-
if lock.value.is_some() {
284+
if self.result.get().is_some() {
280285
// The value is already available, we need to call the callback ourselves
281286
mem::drop(lock);
282287
bound_cb.invoke(self_ptr);
@@ -335,8 +340,12 @@ pub unsafe extern "C" fn cass_future_wait_timed(
335340
pub unsafe extern "C" fn cass_future_ready(
336341
future_raw: CassBorrowedSharedPtr<CassFuture, CMut>,
337342
) -> cass_bool_t {
338-
let state_guard = ArcFFI::as_ref(future_raw).unwrap().state.lock().unwrap();
339-
match state_guard.value {
343+
let Some(future) = ArcFFI::as_ref(future_raw) else {
344+
tracing::error!("Provided null future to cass_future_ready!");
345+
return cass_false;
346+
};
347+
348+
match future.result.get() {
340349
None => cass_false,
341350
Some(_) => cass_true,
342351
}
@@ -348,7 +357,7 @@ pub unsafe extern "C" fn cass_future_error_code(
348357
) -> CassError {
349358
ArcFFI::as_ref(future_raw)
350359
.unwrap()
351-
.with_waited_result(|r: &mut CassFutureResult| match r {
360+
.with_waited_result(|r: &CassFutureResult| match r {
352361
Ok(CassResultValue::QueryError(err)) => err.to_cass_error(),
353362
Err((err, _)) => *err,
354363
_ => CassError::CASS_OK,
@@ -361,19 +370,26 @@ pub unsafe extern "C" fn cass_future_error_message(
361370
message: *mut *const ::std::os::raw::c_char,
362371
message_length: *mut size_t,
363372
) {
364-
ArcFFI::as_ref(future)
365-
.unwrap()
366-
.with_waited_state(|state: &mut CassFutureState| {
367-
let value = &state.value;
368-
let msg = state
369-
.err_string
370-
.get_or_insert_with(|| match value.as_ref().unwrap() {
371-
Ok(CassResultValue::QueryError(err)) => err.msg(),
372-
Err((_, s)) => s.msg(),
373-
_ => "".to_string(),
374-
});
375-
unsafe { write_str_to_c(msg.as_str(), message, message_length) };
376-
});
373+
let Some(future) = ArcFFI::as_ref(future) else {
374+
tracing::error!("Provided null future to cass_future_error_message!");
375+
unsafe {
376+
*message = std::ptr::null();
377+
*message_length = 0;
378+
}
379+
return;
380+
};
381+
382+
future.with_waited_state(|state: &mut CassFutureState| {
383+
let value = future.result.get();
384+
let msg = state
385+
.err_string
386+
.get_or_insert_with(|| match value.as_ref().unwrap() {
387+
Ok(CassResultValue::QueryError(err)) => err.msg(),
388+
Err((_, s)) => s.msg(),
389+
_ => "".to_string(),
390+
});
391+
unsafe { write_str_to_c(msg.as_str(), message, message_length) };
392+
});
377393
}
378394

379395
#[unsafe(no_mangle)]
@@ -387,7 +403,7 @@ pub unsafe extern "C" fn cass_future_get_result(
387403
) -> CassOwnedSharedPtr<CassResult, CConst> {
388404
ArcFFI::as_ref(future_raw)
389405
.unwrap()
390-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassResult>> {
406+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassResult>> {
391407
match r.as_ref().ok()? {
392408
CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)),
393409
_ => None,
@@ -402,7 +418,7 @@ pub unsafe extern "C" fn cass_future_get_error_result(
402418
) -> CassOwnedSharedPtr<CassErrorResult, CConst> {
403419
ArcFFI::as_ref(future_raw)
404420
.unwrap()
405-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassErrorResult>> {
421+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassErrorResult>> {
406422
match r.as_ref().ok()? {
407423
CassResultValue::QueryError(qr) => Some(Arc::clone(qr)),
408424
_ => None,
@@ -417,7 +433,7 @@ pub unsafe extern "C" fn cass_future_get_prepared(
417433
) -> CassOwnedSharedPtr<CassPrepared, CConst> {
418434
ArcFFI::as_ref(future_raw)
419435
.unwrap()
420-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassPrepared>> {
436+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassPrepared>> {
421437
match r.as_ref().ok()? {
422438
CassResultValue::Prepared(p) => Some(Arc::clone(p)),
423439
_ => None,
@@ -433,7 +449,7 @@ pub unsafe extern "C" fn cass_future_tracing_id(
433449
) -> CassError {
434450
ArcFFI::as_ref(future)
435451
.unwrap()
436-
.with_waited_result(|r: &mut CassFutureResult| match r {
452+
.with_waited_result(|r: &CassFutureResult| match r {
437453
Ok(CassResultValue::QueryResult(result)) => match result.tracing_id {
438454
Some(id) => {
439455
unsafe { *tracing_id = CassUuid::from(id) };
@@ -457,21 +473,7 @@ pub unsafe extern "C" fn cass_future_coordinator(
457473
future.with_waited_result(|r| match r {
458474
Ok(CassResultValue::QueryResult(result)) => {
459475
// unwrap: Coordinator is `None` only for tests.
460-
let coordinator_ptr = result.coordinator.as_ref().unwrap() as *const Coordinator;
461-
462-
// We need to 'extend' the lifetime of returned Coordinator so safe FFI api does not complain.
463-
// The lifetime of "result" reference provided to this closure is the lifetime of a mutex guard.
464-
// We are guaranteed, that once the future is resolved (i.e. this closure is called), the result will not
465-
// be modified in any way. Thus, we can guarantee that returned coordinator lives as long as underlying
466-
// CassResult lives (i.e. longer than the lifetime of acquired mutex guard).
467-
//
468-
// SAFETY: Coordinator's lifetime is tied to the lifetime of underlying CassResult, thus:
469-
// 1. Coordinator lives as long as the underlying CassResult lives
470-
// 2. Coordinator will not be moved as long as underlying CassResult is not freed
471-
// 3. Coordinator is immutable once future is resolved (because CassResult is set once)
472-
let coordinator_ref = unsafe { &*coordinator_ptr };
473-
474-
RefFFI::as_ptr(coordinator_ref)
476+
RefFFI::as_ptr(result.coordinator.as_ref().unwrap())
475477
}
476478
_ => RefFFI::null(),
477479
})

scylla-rust-wrapper/src/query_result.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ use std::sync::Arc;
2929
use thiserror::Error;
3030
use uuid::Uuid;
3131

32+
#[derive(Debug)]
3233
pub enum CassResultKind {
3334
NonRows,
3435
Rows(CassRowsResult),
3536
}
3637

38+
#[derive(Debug)]
3739
pub struct CassRowsResult {
3840
// Arc: shared with first_row (yoke).
3941
pub(crate) shared_data: Arc<CassRowsResultSharedData>,
4042
pub(crate) first_row: Option<RowWithSelfBorrowedResultData>,
4143
}
4244

45+
#[derive(Debug)]
4346
pub(crate) struct CassRowsResultSharedData {
4447
pub(crate) raw_rows: DeserializedMetadataAndRawRows,
4548
// Arc: shared with CassPrepared
@@ -53,6 +56,7 @@ impl FFI for CassNode {
5356
type Origin = FromRef;
5457
}
5558

59+
#[derive(Debug)]
5660
pub struct CassResult {
5761
pub tracing_id: Option<Uuid>,
5862
pub paging_state_response: PagingStateResponse,
@@ -159,6 +163,7 @@ impl<'frame, 'metadata> DeserializeRow<'frame, 'metadata> for CassRawRow<'frame,
159163

160164
/// The lifetime of CassRow is bound to CassResult.
161165
/// It will be freed, when CassResult is freed.(see #[cass_result_free])
166+
#[derive(Debug)]
162167
pub struct CassRow<'result> {
163168
pub columns: Vec<CassValue<'result>>,
164169
pub result_metadata: &'result CassResultMetadata,
@@ -230,7 +235,7 @@ mod row_with_self_borrowed_result_data {
230235

231236
/// A simple wrapper over CassRow.
232237
/// Needed, so we can implement Yokeable for it, instead of implementing it for CassRow.
233-
#[derive(Yokeable)]
238+
#[derive(Yokeable, Debug)]
234239
struct CassRowWrapper<'result>(CassRow<'result>);
235240

236241
/// A wrapper over struct which self-borrows the metadata allocated using Arc.
@@ -243,6 +248,7 @@ mod row_with_self_borrowed_result_data {
243248
///
244249
/// This struct is a shared owner of the row bytes and metadata, and self-borrows this data
245250
/// to the `CassRow` it contains.
251+
#[derive(Debug)]
246252
pub struct RowWithSelfBorrowedResultData(
247253
Yoke<CassRowWrapper<'static>, Arc<CassRowsResultSharedData>>,
248254
);
@@ -307,6 +313,7 @@ pub(crate) mod cass_raw_value {
307313
use scylla::errors::{DeserializationError, TypeCheckError};
308314
use thiserror::Error;
309315

316+
#[derive(Debug)]
310317
pub(crate) struct CassRawValue<'frame, 'metadata> {
311318
typ: &'metadata ColumnType<'metadata>,
312319
slice: Option<FrameSlice<'frame>>,
@@ -428,6 +435,7 @@ pub(crate) mod cass_raw_value {
428435
}
429436
}
430437

438+
#[derive(Debug)]
431439
pub struct CassValue<'result> {
432440
pub(crate) value: CassRawValue<'result, 'result>,
433441
pub(crate) value_type: &'result Arc<CassDataType>,

0 commit comments

Comments
 (0)