Skip to content

Don't panic when null pointer is provided - log error and return instead #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions scylla-rust-wrapper/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@ pub unsafe extern "C" fn cass_batch_set_consistency(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
consistency: CassConsistency,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_consistency!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let consistency = match consistency.try_into().ok() {
Some(c) => c,
None => return CassError::CASS_ERROR_LIB_BAD_PARAMS,
Expand All @@ -77,7 +81,11 @@ pub unsafe extern "C" fn cass_batch_set_serial_consistency(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
serial_consistency: CassConsistency,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_serial_consistency!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let serial_consistency = match serial_consistency.try_into().ok() {
Some(c) => c,
None => return CassError::CASS_ERROR_LIB_BAD_PARAMS,
Expand All @@ -94,7 +102,10 @@ pub unsafe extern "C" fn cass_batch_set_retry_policy(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
retry_policy: CassBorrowedSharedPtr<CassRetryPolicy, CMut>,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_retry_policy!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let maybe_arced_retry_policy: Option<Arc<dyn scylla::policies::retry::RetryPolicy>> =
ArcFFI::as_ref(retry_policy).map(|policy| match policy {
Expand All @@ -117,7 +128,10 @@ pub unsafe extern "C" fn cass_batch_set_timestamp(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
timestamp: cass_int64_t,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_timestamp!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

Arc::make_mut(&mut batch.state)
.batch
Expand All @@ -131,7 +145,10 @@ pub unsafe extern "C" fn cass_batch_set_request_timeout(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
timeout_ms: cass_uint64_t,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_request_timeout!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
batch.batch_request_timeout_ms = Some(timeout_ms);

CassError::CASS_OK
Expand All @@ -142,7 +159,11 @@ pub unsafe extern "C" fn cass_batch_set_is_idempotent(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
is_idempotent: cass_bool_t,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_is_idempotent!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

Arc::make_mut(&mut batch.state)
.batch
.set_is_idempotent(is_idempotent != 0);
Expand All @@ -155,7 +176,11 @@ pub unsafe extern "C" fn cass_batch_set_tracing(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
enabled: cass_bool_t,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_set_tracing!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

Arc::make_mut(&mut batch.state)
.batch
.set_tracing(enabled != 0);
Expand All @@ -168,9 +193,16 @@ pub unsafe extern "C" fn cass_batch_add_statement(
batch: CassBorrowedExclusivePtr<CassBatch, CMut>,
statement: CassBorrowedSharedPtr<CassStatement, CMut>,
) -> CassError {
let batch = BoxFFI::as_mut_ref(batch).unwrap();
let Some(batch) = BoxFFI::as_mut_ref(batch) else {
tracing::error!("Provided null batch pointer to cass_batch_add_statement!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
let Some(statement) = BoxFFI::as_ref(statement) else {
tracing::error!("Provided null statement pointer to cass_batch_add_statement!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let state = Arc::make_mut(&mut batch.state);
let statement = BoxFFI::as_ref(statement).unwrap();

match &statement.statement {
BoundStatement::Simple(q) => {
Expand Down
24 changes: 20 additions & 4 deletions scylla-rust-wrapper/src/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ macro_rules! make_index_binder {
// For some reason detected as unused, which is not true
#[allow(unused_imports)]
use crate::value::CassCqlValue::*;
let Some(this) = BoxFFI::as_mut_ref(this) else {
tracing::error!("Provided null pointer to {}!", stringify!($fn_by_idx));
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
match ($e)($($arg), *) {
Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), index as usize, v),
Ok(v) => $consume_v(this, index as usize, v),
Err(e) => e,
}
}
Expand All @@ -80,9 +84,13 @@ macro_rules! make_name_binder {
// For some reason detected as unused, which is not true
#[allow(unused_imports)]
use crate::value::CassCqlValue::*;
let Some(this) = BoxFFI::as_mut_ref(this) else {
tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name));
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
let name = unsafe { ptr_to_cstr(name) }.unwrap();
match ($e)($($arg), *) {
Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), name, v),
Ok(v) => $consume_v(this, name, v),
Err(e) => e,
}
}
Expand All @@ -102,9 +110,13 @@ macro_rules! make_name_n_binder {
// For some reason detected as unused, which is not true
#[allow(unused_imports)]
use crate::value::CassCqlValue::*;
let Some(this) = BoxFFI::as_mut_ref(this) else {
tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name_n));
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
let name = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap();
match ($e)($($arg), *) {
Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), name, v),
Ok(v) => $consume_v(this, name, v),
Err(e) => e,
}
}
Expand All @@ -122,8 +134,12 @@ macro_rules! make_appender {
// For some reason detected as unused, which is not true
#[allow(unused_imports)]
use crate::value::CassCqlValue::*;
let Some(this) = BoxFFI::as_mut_ref(this) else {
tracing::error!("Provided null pointer to {}!", stringify!($fn_append));
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
match ($e)($($arg), *) {
Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), v),
Ok(v) => $consume_v(this, v),
Err(e) => e,
}
}
Expand Down
109 changes: 90 additions & 19 deletions scylla-rust-wrapper/src/cass_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,11 @@ pub unsafe extern "C" fn cass_data_type_new(
pub unsafe extern "C" fn cass_data_type_new_from_existing(
data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> CassOwnedSharedPtr<CassDataType, CMut> {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_new_from_existing!");
return ArcFFI::null();
};

ArcFFI::into_ptr(CassDataType::new_arced(
unsafe { data_type.get_unchecked() }.clone(),
))
Expand Down Expand Up @@ -507,15 +511,23 @@ pub unsafe extern "C" fn cass_data_type_free(data_type: CassOwnedSharedPtr<CassD
pub unsafe extern "C" fn cass_data_type_type(
data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> CassValueType {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_type!");
return CassValueType::CASS_VALUE_TYPE_UNKNOWN;
};

unsafe { data_type.get_unchecked() }.get_value_type()
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn cass_data_type_is_frozen(
data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> cass_bool_t {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_is_frozen!");
return cass_false;
};

let is_frozen = match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::UDT(udt) => udt.frozen,
CassDataTypeInner::List { frozen, .. } => *frozen,
Expand All @@ -533,7 +545,11 @@ pub unsafe extern "C" fn cass_data_type_type_name(
type_name: *mut *const c_char,
type_name_length: *mut size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_type_name!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::UDT(UDTDataType { name, .. }) => {
unsafe { write_str_to_c(name, type_name, type_name_length) };
Expand All @@ -557,7 +573,11 @@ pub unsafe extern "C" fn cass_data_type_set_type_name_n(
type_name: *const c_char,
type_name_length: size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type_raw).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type_raw) else {
tracing::error!("Provided null data type pointer to cass_data_type_set_type_name_n!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let type_name_string = unsafe { ptr_to_cstr_n(type_name, type_name_length) }
.unwrap()
.to_string();
Expand All @@ -577,7 +597,11 @@ pub unsafe extern "C" fn cass_data_type_keyspace(
keyspace: *mut *const c_char,
keyspace_length: *mut size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_keyspace!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::UDT(UDTDataType { name, .. }) => {
unsafe { write_str_to_c(name, keyspace, keyspace_length) };
Expand All @@ -601,7 +625,11 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace_n(
keyspace: *const c_char,
keyspace_length: size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_set_keyspace_n!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let keyspace_string = unsafe { ptr_to_cstr_n(keyspace, keyspace_length) }
.unwrap()
.to_string();
Expand All @@ -621,7 +649,11 @@ pub unsafe extern "C" fn cass_data_type_class_name(
class_name: *mut *const ::std::os::raw::c_char,
class_name_length: *mut size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_class_name!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::Custom(name) => {
unsafe { write_str_to_c(name, class_name, class_name_length) };
Expand All @@ -645,7 +677,11 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n(
class_name: *const ::std::os::raw::c_char,
class_name_length: size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_set_class_name_n!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let class_string = unsafe { ptr_to_cstr_n(class_name, class_name_length) }
.unwrap()
.to_string();
Expand All @@ -662,7 +698,11 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n(
pub unsafe extern "C" fn cass_data_type_sub_type_count(
data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> size_t {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_sub_type_count!");
return 0;
};

match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::Value(..) => 0,
CassDataTypeInner::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t,
Expand Down Expand Up @@ -691,7 +731,11 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type(
data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
index: size_t,
) -> CassBorrowedSharedPtr<CassDataType, CConst> {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_sub_data_type!");
return ArcFFI::null();
};

let sub_type: Option<&Arc<CassDataType>> =
unsafe { data_type.get_unchecked() }.get_sub_data_type(index as usize);

Expand All @@ -716,7 +760,13 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n(
name: *const ::std::os::raw::c_char,
name_length: size_t,
) -> CassBorrowedSharedPtr<CassDataType, CConst> {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!(
"Provided null data type pointer to cass_data_type_sub_data_type_by_name_n!"
);
return ArcFFI::null();
};

let name_str = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap();
match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::UDT(udt) => match udt.get_field_by_name(name_str) {
Expand All @@ -734,7 +784,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name(
name: *mut *const ::std::os::raw::c_char,
name_length: *mut size_t,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_sub_type_name!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

match unsafe { data_type.get_unchecked() } {
CassDataTypeInner::UDT(udt) => match udt.field_types.get(index as usize) {
None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS,
Expand All @@ -752,10 +806,16 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type(
data_type: CassBorrowedSharedPtr<CassDataType, CMut>,
sub_data_type: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> CassError {
let data_type = ArcFFI::as_ref(data_type).unwrap();
match unsafe { data_type.get_mut_unchecked() }
.add_sub_data_type(ArcFFI::cloned_from_ptr(sub_data_type).unwrap())
{
let Some(data_type) = ArcFFI::as_ref(data_type) else {
tracing::error!("Provided null data type pointer to cass_data_type_add_sub_type!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type) else {
tracing::error!("Provided null sub data type pointer to cass_data_type_add_sub_type!");
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

match unsafe { data_type.get_mut_unchecked() }.add_sub_data_type(sub_data_type) {
Ok(()) => CassError::CASS_OK,
Err(e) => e,
}
Expand All @@ -777,12 +837,23 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n(
name_length: size_t,
sub_data_type_raw: CassBorrowedSharedPtr<CassDataType, CConst>,
) -> CassError {
let Some(data_type) = ArcFFI::as_ref(data_type_raw) else {
tracing::error!(
"Provided null data type pointer to cass_data_type_add_sub_type_by_name_n!"
);
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};
let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type_raw) else {
tracing::error!(
"Provided null sub data type pointer to cass_data_type_add_sub_type_by_name_n!"
);
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
};

let name_string = unsafe { ptr_to_cstr_n(name, name_length) }
.unwrap()
.to_string();
let sub_data_type = ArcFFI::cloned_from_ptr(sub_data_type_raw).unwrap();

let data_type = ArcFFI::as_ref(data_type_raw).unwrap();
match unsafe { data_type.get_mut_unchecked() } {
CassDataTypeInner::UDT(udt_data_type) => {
// The Cpp Driver does not check whether field_types size
Expand Down
Loading