Skip to content

Commit 665266c

Browse files
committed
future: store result in OnceLock
The result is going to be initialized only once, thus we do not need to store it behind a mutex. We can use OnceLock instead. Thanks to that, we can remove the unsafe logic which extends the lifetime of `coordinator` reference in cass_future_coordinator. We now guarantee that the result will be immutable once future is resolved - the guarantee is provided on the type-level.
1 parent d4b32e5 commit 665266c

File tree

2 files changed

+37
-39
lines changed

2 files changed

+37
-39
lines changed

scylla-rust-wrapper/src/future.rs

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ use crate::query_result::{CassNode, CassResult};
99
use crate::types::*;
1010
use crate::uuid::CassUuid;
1111
use futures::future;
12-
use scylla::response::Coordinator;
1312
use std::future::Future;
1413
use std::mem;
1514
use std::os::raw::c_void;
16-
use std::sync::{Arc, Condvar, Mutex};
15+
use std::sync::{Arc, Condvar, Mutex, OnceLock};
1716
use tokio::task::JoinHandle;
1817
use tokio::time::Duration;
1918

19+
#[derive(Debug)]
2020
pub enum CassResultValue {
2121
Empty,
2222
QueryResult(Arc<CassResult>),
@@ -51,14 +51,14 @@ impl BoundCallback {
5151

5252
#[derive(Default)]
5353
struct CassFutureState {
54-
value: Option<CassFutureResult>,
5554
err_string: Option<String>,
5655
callback: Option<BoundCallback>,
5756
join_handle: Option<JoinHandle<()>>,
5857
}
5958

6059
pub struct CassFuture {
6160
state: Mutex<CassFutureState>,
61+
result: OnceLock<CassFutureResult>,
6262
wait_for_value: Condvar,
6363
}
6464

@@ -88,14 +88,18 @@ impl CassFuture {
8888
) -> Arc<CassFuture> {
8989
let cass_fut = Arc::new(CassFuture {
9090
state: Mutex::new(Default::default()),
91+
result: OnceLock::new(),
9192
wait_for_value: Condvar::new(),
9293
});
9394
let cass_fut_clone = Arc::clone(&cass_fut);
9495
let join_handle = RUNTIME.spawn(async move {
9596
let r = fut.await;
9697
let maybe_cb = {
9798
let mut guard = cass_fut_clone.state.lock().unwrap();
98-
guard.value = Some(r);
99+
cass_fut_clone
100+
.result
101+
.set(r)
102+
.expect("Tried to resolve future result twice!");
99103
// Take the callback and call it after releasing the lock
100104
guard.callback.take()
101105
};
@@ -116,16 +120,17 @@ impl CassFuture {
116120

117121
pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
118122
Arc::new(CassFuture {
119-
state: Mutex::new(CassFutureState {
120-
value: Some(r),
121-
..Default::default()
122-
}),
123+
state: Mutex::new(CassFutureState::default()),
124+
result: OnceLock::from(r),
123125
wait_for_value: Condvar::new(),
124126
})
125127
}
126128

127-
pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
128-
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
129+
pub fn with_waited_result<'s, T>(&'s self, f: impl FnOnce(&'s CassFutureResult) -> T) -> T
130+
where
131+
T: 's,
132+
{
133+
self.with_waited_state(|_| f(self.result.get().unwrap()))
129134
}
130135

131136
/// Awaits the future until completion.
@@ -154,7 +159,7 @@ impl CassFuture {
154159
guard = self
155160
.wait_for_value
156161
.wait_while(guard, |state| {
157-
state.value.is_none() && state.join_handle.is_none()
162+
self.result.get().is_none() && state.join_handle.is_none()
158163
})
159164
// unwrap: Error appears only when mutex is poisoned.
160165
.unwrap();
@@ -172,10 +177,10 @@ impl CassFuture {
172177

173178
fn with_waited_result_timed<T>(
174179
&self,
175-
f: impl FnOnce(&mut CassFutureResult) -> T,
180+
f: impl FnOnce(&CassFutureResult) -> T,
176181
timeout_duration: Duration,
177182
) -> Result<T, FutureError> {
178-
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
183+
self.with_waited_state_timed(|_| f(self.result.get().unwrap()), timeout_duration)
179184
}
180185

181186
/// Tries to await the future with a given timeout.
@@ -243,7 +248,7 @@ impl CassFuture {
243248
let (guard_result, timeout_result) = self
244249
.wait_for_value
245250
.wait_timeout_while(guard, remaining_timeout, |state| {
246-
state.value.is_none() && state.join_handle.is_none()
251+
self.result.get().is_none() && state.join_handle.is_none()
247252
})
248253
// unwrap: Error appears only when mutex is poisoned.
249254
.unwrap();
@@ -276,7 +281,7 @@ impl CassFuture {
276281
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
277282
}
278283
let bound_cb = BoundCallback { cb, data };
279-
if lock.value.is_some() {
284+
if self.result.get().is_some() {
280285
// The value is already available, we need to call the callback ourselves
281286
mem::drop(lock);
282287
bound_cb.invoke(self_ptr);
@@ -346,8 +351,7 @@ pub unsafe extern "C" fn cass_future_ready(
346351
return cass_false;
347352
};
348353

349-
let state_guard = future.state.lock().unwrap();
350-
match state_guard.value {
354+
match future.result.get() {
351355
None => cass_false,
352356
Some(_) => cass_true,
353357
}
@@ -362,7 +366,7 @@ pub unsafe extern "C" fn cass_future_error_code(
362366
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
363367
};
364368

365-
future.with_waited_result(|r: &mut CassFutureResult| match r {
369+
future.with_waited_result(|r: &CassFutureResult| match r {
366370
Ok(CassResultValue::QueryError(err)) => err.to_cass_error(),
367371
Err((err, _)) => *err,
368372
_ => CassError::CASS_OK,
@@ -381,7 +385,7 @@ pub unsafe extern "C" fn cass_future_error_message(
381385
};
382386

383387
future.with_waited_state(|state: &mut CassFutureState| {
384-
let value = &state.value;
388+
let value = future.result.get();
385389
let msg = state
386390
.err_string
387391
.get_or_insert_with(|| match value.as_ref().unwrap() {
@@ -408,7 +412,7 @@ pub unsafe extern "C" fn cass_future_get_result(
408412
};
409413

410414
future
411-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassResult>> {
415+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassResult>> {
412416
match r.as_ref().ok()? {
413417
CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)),
414418
_ => None,
@@ -427,7 +431,7 @@ pub unsafe extern "C" fn cass_future_get_error_result(
427431
};
428432

429433
future
430-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassErrorResult>> {
434+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassErrorResult>> {
431435
match r.as_ref().ok()? {
432436
CassResultValue::QueryError(qr) => Some(Arc::clone(qr)),
433437
_ => None,
@@ -446,7 +450,7 @@ pub unsafe extern "C" fn cass_future_get_prepared(
446450
};
447451

448452
future
449-
.with_waited_result(|r: &mut CassFutureResult| -> Option<Arc<CassPrepared>> {
453+
.with_waited_result(|r: &CassFutureResult| -> Option<Arc<CassPrepared>> {
450454
match r.as_ref().ok()? {
451455
CassResultValue::Prepared(p) => Some(Arc::clone(p)),
452456
_ => None,
@@ -465,7 +469,7 @@ pub unsafe extern "C" fn cass_future_tracing_id(
465469
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
466470
};
467471

468-
future.with_waited_result(|r: &mut CassFutureResult| match r {
472+
future.with_waited_result(|r: &CassFutureResult| match r {
469473
Ok(CassResultValue::QueryResult(result)) => match result.tracing_id {
470474
Some(id) => {
471475
unsafe { *tracing_id = CassUuid::from(id) };
@@ -489,21 +493,7 @@ pub unsafe extern "C" fn cass_future_coordinator(
489493
future.with_waited_result(|r| match r {
490494
Ok(CassResultValue::QueryResult(result)) => {
491495
// unwrap: Coordinator is `None` only for tests.
492-
let coordinator_ptr = result.coordinator.as_ref().unwrap() as *const Coordinator;
493-
494-
// We need to 'extend' the lifetime of returned Coordinator so safe FFI api does not complain.
495-
// The lifetime of "result" reference provided to this closure is the lifetime of a mutex guard.
496-
// We are guaranteed, that once the future is resolved (i.e. this closure is called), the result will not
497-
// be modified in any way. Thus, we can guarantee that returned coordinator lives as long as underlying
498-
// CassResult lives (i.e. longer than the lifetime of acquired mutex guard).
499-
//
500-
// SAFETY: Coordinator's lifetime is tied to the lifetime of underlying CassResult, thus:
501-
// 1. Coordinator lives as long as the underlying CassResult lives
502-
// 2. Coordinator will not be moved as long as underlying CassResult is not freed
503-
// 3. Coordinator is immutable once future is resolved (because CassResult is set once)
504-
let coordinator_ref = unsafe { &*coordinator_ptr };
505-
506-
RefFFI::as_ptr(coordinator_ref)
496+
RefFFI::as_ptr(result.coordinator.as_ref().unwrap())
507497
}
508498
_ => RefFFI::null(),
509499
})

scylla-rust-wrapper/src/query_result.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,20 @@ use std::sync::Arc;
2929
use thiserror::Error;
3030
use uuid::Uuid;
3131

32+
#[derive(Debug)]
3233
pub enum CassResultKind {
3334
NonRows,
3435
Rows(CassRowsResult),
3536
}
3637

38+
#[derive(Debug)]
3739
pub struct CassRowsResult {
3840
// Arc: shared with first_row (yoke).
3941
pub(crate) shared_data: Arc<CassRowsResultSharedData>,
4042
pub(crate) first_row: Option<RowWithSelfBorrowedResultData>,
4143
}
4244

45+
#[derive(Debug)]
4346
pub(crate) struct CassRowsResultSharedData {
4447
pub(crate) raw_rows: DeserializedMetadataAndRawRows,
4548
// Arc: shared with CassPrepared
@@ -53,6 +56,7 @@ impl FFI for CassNode {
5356
type Origin = FromRef;
5457
}
5558

59+
#[derive(Debug)]
5660
pub struct CassResult {
5761
pub tracing_id: Option<Uuid>,
5862
pub paging_state_response: PagingStateResponse,
@@ -159,6 +163,7 @@ impl<'frame, 'metadata> DeserializeRow<'frame, 'metadata> for CassRawRow<'frame,
159163

160164
/// The lifetime of CassRow is bound to CassResult.
161165
/// It will be freed, when CassResult is freed.(see #[cass_result_free])
166+
#[derive(Debug)]
162167
pub struct CassRow<'result> {
163168
pub columns: Vec<CassValue<'result>>,
164169
pub result_metadata: &'result CassResultMetadata,
@@ -230,7 +235,7 @@ mod row_with_self_borrowed_result_data {
230235

231236
/// A simple wrapper over CassRow.
232237
/// Needed, so we can implement Yokeable for it, instead of implementing it for CassRow.
233-
#[derive(Yokeable)]
238+
#[derive(Yokeable, Debug)]
234239
struct CassRowWrapper<'result>(CassRow<'result>);
235240

236241
/// A wrapper over struct which self-borrows the metadata allocated using Arc.
@@ -243,6 +248,7 @@ mod row_with_self_borrowed_result_data {
243248
///
244249
/// This struct is a shared owner of the row bytes and metadata, and self-borrows this data
245250
/// to the `CassRow` it contains.
251+
#[derive(Debug)]
246252
pub struct RowWithSelfBorrowedResultData(
247253
Yoke<CassRowWrapper<'static>, Arc<CassRowsResultSharedData>>,
248254
);
@@ -307,6 +313,7 @@ pub(crate) mod cass_raw_value {
307313
use scylla::errors::{DeserializationError, TypeCheckError};
308314
use thiserror::Error;
309315

316+
#[derive(Debug)]
310317
pub(crate) struct CassRawValue<'frame, 'metadata> {
311318
typ: &'metadata ColumnType<'metadata>,
312319
slice: Option<FrameSlice<'frame>>,
@@ -428,6 +435,7 @@ pub(crate) mod cass_raw_value {
428435
}
429436
}
430437

438+
#[derive(Debug)]
431439
pub struct CassValue<'result> {
432440
pub(crate) value: CassRawValue<'result, 'result>,
433441
pub(crate) value_type: &'result Arc<CassDataType>,

0 commit comments

Comments
 (0)