diff --git a/scylla-rust-wrapper/src/batch.rs b/scylla-rust-wrapper/src/batch.rs index 7968c84d..b777deae 100644 --- a/scylla-rust-wrapper/src/batch.rs +++ b/scylla-rust-wrapper/src/batch.rs @@ -60,7 +60,11 @@ pub unsafe extern "C" fn cass_batch_set_consistency( batch: CassBorrowedExclusivePtr, consistency: CassConsistency, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let consistency = match consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -77,7 +81,11 @@ pub unsafe extern "C" fn cass_batch_set_serial_consistency( batch: CassBorrowedExclusivePtr, serial_consistency: CassConsistency, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_serial_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let serial_consistency = match serial_consistency.try_into().ok() { Some(c) => c, None => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -94,7 +102,10 @@ pub unsafe extern "C" fn cass_batch_set_retry_policy( batch: CassBorrowedExclusivePtr, retry_policy: CassBorrowedSharedPtr, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_retry_policy!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; let maybe_arced_retry_policy: Option> = ArcFFI::as_ref(retry_policy).map(|policy| match policy { @@ -117,7 +128,10 @@ pub unsafe extern "C" fn cass_batch_set_timestamp( batch: CassBorrowedExclusivePtr, timestamp: cass_int64_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_timestamp!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; Arc::make_mut(&mut batch.state) .batch @@ -131,7 +145,10 @@ pub unsafe extern "C" fn cass_batch_set_request_timeout( batch: CassBorrowedExclusivePtr, timeout_ms: cass_uint64_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_request_timeout!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; batch.batch_request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -142,7 +159,11 @@ pub unsafe extern "C" fn cass_batch_set_is_idempotent( batch: CassBorrowedExclusivePtr, is_idempotent: cass_bool_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_is_idempotent!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + Arc::make_mut(&mut batch.state) .batch .set_is_idempotent(is_idempotent != 0); @@ -155,7 +176,11 @@ pub unsafe extern "C" fn cass_batch_set_tracing( batch: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_tracing!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + Arc::make_mut(&mut batch.state) .batch .set_tracing(enabled != 0); @@ -168,9 +193,16 @@ pub unsafe extern "C" fn cass_batch_add_statement( batch: CassBorrowedExclusivePtr, statement: CassBorrowedSharedPtr, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_add_statement!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(statement) = BoxFFI::as_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_batch_add_statement!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let state = Arc::make_mut(&mut batch.state); - let statement = BoxFFI::as_ref(statement).unwrap(); match &statement.statement { BoundStatement::Simple(q) => { diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index 47f91ba9..38bc4554 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -60,8 +60,12 @@ macro_rules! make_index_binder { // For some reason detected as unused, which is not true #[allow(unused_imports)] use crate::value::CassCqlValue::*; + let Some(this) = BoxFFI::as_mut_ref(this) else { + tracing::error!("Provided null pointer to {}!", stringify!($fn_by_idx)); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), index as usize, v), + Ok(v) => $consume_v(this, index as usize, v), Err(e) => e, } } @@ -80,9 +84,13 @@ macro_rules! make_name_binder { // For some reason detected as unused, which is not true #[allow(unused_imports)] use crate::value::CassCqlValue::*; + let Some(this) = BoxFFI::as_mut_ref(this) else { + tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name)); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; let name = unsafe { ptr_to_cstr(name) }.unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), name, v), + Ok(v) => $consume_v(this, name, v), Err(e) => e, } } @@ -102,9 +110,13 @@ macro_rules! make_name_n_binder { // For some reason detected as unused, which is not true #[allow(unused_imports)] use crate::value::CassCqlValue::*; + let Some(this) = BoxFFI::as_mut_ref(this) else { + tracing::error!("Provided null pointer to {}!", stringify!($fn_by_name_n)); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; let name = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap(); match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), name, v), + Ok(v) => $consume_v(this, name, v), Err(e) => e, } } @@ -122,8 +134,12 @@ macro_rules! make_appender { // For some reason detected as unused, which is not true #[allow(unused_imports)] use crate::value::CassCqlValue::*; + let Some(this) = BoxFFI::as_mut_ref(this) else { + tracing::error!("Provided null pointer to {}!", stringify!($fn_append)); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; match ($e)($($arg), *) { - Ok(v) => $consume_v(BoxFFI::as_mut_ref(this).unwrap(), v), + Ok(v) => $consume_v(this, v), Err(e) => e, } } diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index da27e6d2..2dc3ade2 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -474,7 +474,11 @@ pub unsafe extern "C" fn cass_data_type_new( pub unsafe extern "C" fn cass_data_type_new_from_existing( data_type: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_new_from_existing!"); + return ArcFFI::null(); + }; + ArcFFI::into_ptr(CassDataType::new_arced( unsafe { data_type.get_unchecked() }.clone(), )) @@ -507,7 +511,11 @@ pub unsafe extern "C" fn cass_data_type_free(data_type: CassOwnedSharedPtr, ) -> CassValueType { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_type!"); + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + }; + unsafe { data_type.get_unchecked() }.get_value_type() } @@ -515,7 +523,11 @@ pub unsafe extern "C" fn cass_data_type_type( pub unsafe extern "C" fn cass_data_type_is_frozen( data_type: CassBorrowedSharedPtr, ) -> cass_bool_t { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_is_frozen!"); + return cass_false; + }; + let is_frozen = match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(udt) => udt.frozen, CassDataTypeInner::List { frozen, .. } => *frozen, @@ -533,7 +545,11 @@ pub unsafe extern "C" fn cass_data_type_type_name( type_name: *mut *const c_char, type_name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_type_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { unsafe { write_str_to_c(name, type_name, type_name_length) }; @@ -557,7 +573,11 @@ pub unsafe extern "C" fn cass_data_type_set_type_name_n( type_name: *const c_char, type_name_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type_raw).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_type_name_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let type_name_string = unsafe { ptr_to_cstr_n(type_name, type_name_length) } .unwrap() .to_string(); @@ -577,7 +597,11 @@ pub unsafe extern "C" fn cass_data_type_keyspace( keyspace: *mut *const c_char, keyspace_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_keyspace!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(UDTDataType { name, .. }) => { unsafe { write_str_to_c(name, keyspace, keyspace_length) }; @@ -601,7 +625,11 @@ pub unsafe extern "C" fn cass_data_type_set_keyspace_n( keyspace: *const c_char, keyspace_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_keyspace_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let keyspace_string = unsafe { ptr_to_cstr_n(keyspace, keyspace_length) } .unwrap() .to_string(); @@ -621,7 +649,11 @@ pub unsafe extern "C" fn cass_data_type_class_name( class_name: *mut *const ::std::os::raw::c_char, class_name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_class_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match unsafe { data_type.get_unchecked() } { CassDataTypeInner::Custom(name) => { unsafe { write_str_to_c(name, class_name, class_name_length) }; @@ -645,7 +677,11 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( class_name: *const ::std::os::raw::c_char, class_name_length: size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_set_class_name_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let class_string = unsafe { ptr_to_cstr_n(class_name, class_name_length) } .unwrap() .to_string(); @@ -662,7 +698,11 @@ pub unsafe extern "C" fn cass_data_type_set_class_name_n( pub unsafe extern "C" fn cass_data_type_sub_type_count( data_type: CassBorrowedSharedPtr, ) -> size_t { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_type_count!"); + return 0; + }; + match unsafe { data_type.get_unchecked() } { CassDataTypeInner::Value(..) => 0, CassDataTypeInner::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t, @@ -691,7 +731,11 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type( data_type: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_data_type!"); + return ArcFFI::null(); + }; + let sub_type: Option<&Arc> = unsafe { data_type.get_unchecked() }.get_sub_data_type(index as usize); @@ -716,7 +760,13 @@ pub unsafe extern "C" fn cass_data_type_sub_data_type_by_name_n( name: *const ::std::os::raw::c_char, name_length: size_t, ) -> CassBorrowedSharedPtr { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!( + "Provided null data type pointer to cass_data_type_sub_data_type_by_name_n!" + ); + return ArcFFI::null(); + }; + let name_str = unsafe { ptr_to_cstr_n(name, name_length) }.unwrap(); match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(udt) => match udt.get_field_by_name(name_str) { @@ -734,7 +784,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_name( name: *mut *const ::std::os::raw::c_char, name_length: *mut size_t, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_sub_type_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(udt) => match udt.field_types.get(index as usize) { None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS, @@ -752,10 +806,16 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type( data_type: CassBorrowedSharedPtr, sub_data_type: CassBorrowedSharedPtr, ) -> CassError { - let data_type = ArcFFI::as_ref(data_type).unwrap(); - match unsafe { data_type.get_mut_unchecked() } - .add_sub_data_type(ArcFFI::cloned_from_ptr(sub_data_type).unwrap()) - { + let Some(data_type) = ArcFFI::as_ref(data_type) else { + tracing::error!("Provided null data type pointer to cass_data_type_add_sub_type!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type) else { + tracing::error!("Provided null sub data type pointer to cass_data_type_add_sub_type!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match unsafe { data_type.get_mut_unchecked() }.add_sub_data_type(sub_data_type) { Ok(()) => CassError::CASS_OK, Err(e) => e, } @@ -777,12 +837,23 @@ pub unsafe extern "C" fn cass_data_type_add_sub_type_by_name_n( name_length: size_t, sub_data_type_raw: CassBorrowedSharedPtr, ) -> CassError { + let Some(data_type) = ArcFFI::as_ref(data_type_raw) else { + tracing::error!( + "Provided null data type pointer to cass_data_type_add_sub_type_by_name_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(sub_data_type) = ArcFFI::cloned_from_ptr(sub_data_type_raw) else { + tracing::error!( + "Provided null sub data type pointer to cass_data_type_add_sub_type_by_name_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let name_string = unsafe { ptr_to_cstr_n(name, name_length) } .unwrap() .to_string(); - let sub_data_type = ArcFFI::cloned_from_ptr(sub_data_type_raw).unwrap(); - let data_type = ArcFFI::as_ref(data_type_raw).unwrap(); match unsafe { data_type.get_mut_unchecked() } { CassDataTypeInner::UDT(udt_data_type) => { // The Cpp Driver does not check whether field_types size diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index c9e0a860..043f78fe 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -381,7 +381,10 @@ pub unsafe extern "C" fn cass_cluster_set_application_name_n( app_name: *const c_char, app_name_len: size_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_application_name_n!"); + return; + }; let app_name = unsafe { ptr_to_cstr_n(app_name, app_name_len) } .unwrap() .to_string(); @@ -407,7 +410,10 @@ pub unsafe extern "C" fn cass_cluster_set_application_version_n( app_version: *const c_char, app_version_len: size_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_application_version_n!"); + return; + }; let app_version = unsafe { ptr_to_cstr_n(app_version, app_version_len) } .unwrap() .to_string(); @@ -424,7 +430,10 @@ pub unsafe extern "C" fn cass_cluster_set_client_id( cluster_raw: CassBorrowedExclusivePtr, client_id: CassUuid, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_client_id!"); + return; + }; let client_uuid: uuid::Uuid = client_id.into(); let client_uuid_str = client_uuid.to_string(); @@ -442,7 +451,11 @@ pub unsafe extern "C" fn cass_cluster_set_use_schema( cluster_raw: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_use_schema!"); + return; + }; + cluster.session_builder.config.fetch_schema_metadata = enabled != 0; } @@ -451,7 +464,11 @@ pub unsafe extern "C" fn cass_cluster_set_tcp_nodelay( cluster_raw: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_tcp_nodelay!"); + return; + }; + cluster.session_builder.config.tcp_nodelay = enabled != 0; } @@ -461,7 +478,11 @@ pub unsafe extern "C" fn cass_cluster_set_tcp_keepalive( enabled: cass_bool_t, delay_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_tcp_keepalive!"); + return; + }; + let enabled = enabled != 0; let tcp_keepalive_interval = enabled.then(|| Duration::from_secs(delay_secs as u64)); @@ -500,7 +521,13 @@ pub unsafe extern "C" fn cass_cluster_set_connection_heartbeat_interval( cluster_raw: CassBorrowedExclusivePtr, interval_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_connection_heartbeat_interval!" + ); + return; + }; + let keepalive_interval = (interval_secs > 0).then(|| Duration::from_secs(interval_secs as u64)); cluster.session_builder.config.keepalive_interval = keepalive_interval; @@ -511,7 +538,13 @@ pub unsafe extern "C" fn cass_cluster_set_connection_idle_timeout( cluster_raw: CassBorrowedExclusivePtr, timeout_secs: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_connection_idle_timeout!" + ); + return; + }; + let keepalive_timeout = (timeout_secs > 0).then(|| Duration::from_secs(timeout_secs as u64)); cluster.session_builder.config.keepalive_timeout = keepalive_timeout; @@ -522,7 +555,11 @@ pub unsafe extern "C" fn cass_cluster_set_connect_timeout( cluster_raw: CassBorrowedExclusivePtr, timeout_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_connect_timeout!"); + return; + }; + cluster.session_builder.config.connect_timeout = Duration::from_millis(timeout_ms.into()); } @@ -618,7 +655,10 @@ pub unsafe extern "C" fn cass_cluster_set_request_timeout( cluster_raw: CassBorrowedExclusivePtr, timeout_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_request_timeout!"); + return; + }; exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { // 0 -> no timeout @@ -631,7 +671,10 @@ pub unsafe extern "C" fn cass_cluster_set_max_schema_wait_time( cluster_raw: CassBorrowedExclusivePtr, wait_time_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_max_schema_wait_time!"); + return; + }; cluster.session_builder.config.schema_agreement_timeout = Duration::from_millis(wait_time_ms.into()); @@ -642,7 +685,12 @@ pub unsafe extern "C" fn cass_cluster_set_schema_agreement_interval( cluster_raw: CassBorrowedExclusivePtr, interval_ms: c_uint, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_schema_agreement_interval!" + ); + return; + }; cluster.session_builder.config.schema_agreement_interval = Duration::from_millis(interval_ms.into()); @@ -653,12 +701,16 @@ pub unsafe extern "C" fn cass_cluster_set_port( cluster_raw: CassBorrowedExclusivePtr, port: c_int, ) -> CassError { - if port <= 0 { + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_port!"); return CassError::CASS_ERROR_LIB_BAD_PARAMS; - } + }; + let Ok(port): Result = port.try_into() else { + tracing::error!("Provided invalid port number to cass_cluster_set_port!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); - cluster.port = port as u16; + cluster.port = port; CassError::CASS_OK } @@ -774,11 +826,14 @@ pub unsafe extern "C" fn cass_cluster_set_credentials_n( password_raw: *const c_char, password_length: size_t, ) { + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_credentials_n!"); + return; + }; // TODO: string error handling let username = unsafe { ptr_to_cstr_n(username_raw, username_length) }.unwrap(); let password = unsafe { ptr_to_cstr_n(password_raw, password_length) }.unwrap(); - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); cluster.auth_username = Some(username.to_string()); cluster.auth_password = Some(password.to_string()); } @@ -787,7 +842,13 @@ pub unsafe extern "C" fn cass_cluster_set_credentials_n( pub unsafe extern "C" fn cass_cluster_set_load_balance_round_robin( cluster_raw: CassBorrowedExclusivePtr, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_load_balance_round_robin!" + ); + return; + }; + cluster.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); } @@ -842,7 +903,12 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_dc_aware_n( used_hosts_per_remote_dc: c_uint, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_load_balance_dc_aware_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; unsafe { set_load_balance_dc_aware_n( @@ -880,7 +946,12 @@ pub unsafe extern "C" fn cass_cluster_set_load_balance_rack_aware_n( local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_load_balance_rack_aware_n!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; unsafe { set_load_balance_rack_aware_n( @@ -996,7 +1067,13 @@ pub unsafe extern "C" fn cass_cluster_set_use_beta_protocol_version( cluster_raw: CassBorrowedExclusivePtr, enable: cass_bool_t, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_use_beta_protocol_version!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + cluster.use_beta_protocol_version = enable == cass_true; CassError::CASS_OK @@ -1007,7 +1084,10 @@ pub unsafe extern "C" fn cass_cluster_set_protocol_version( cluster_raw: CassBorrowedExclusivePtr, protocol_version: c_int, ) -> CassError { - let cluster = BoxFFI::as_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_protocol_version!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; if protocol_version == 4 && !cluster.use_beta_protocol_version { // Rust Driver supports only protocol version 4 @@ -1032,12 +1112,17 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( constant_delay_ms: cass_int64_t, max_speculative_executions: c_int, ) -> CassError { + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_constant_speculative_execution_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + if constant_delay_ms < 0 || max_speculative_executions < 0 { return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); - let policy = SimpleSpeculativeExecutionPolicy { max_retry_count: max_speculative_executions as usize, retry_interval: Duration::from_millis(constant_delay_ms as u64), @@ -1054,7 +1139,12 @@ pub unsafe extern "C" fn cass_cluster_set_constant_speculative_execution_policy( pub unsafe extern "C" fn cass_cluster_set_no_speculative_execution_policy( cluster_raw: CassBorrowedExclusivePtr, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_no_speculative_execution_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { builder.speculative_execution_policy(None) @@ -1068,7 +1158,11 @@ pub unsafe extern "C" fn cass_cluster_set_token_aware_routing( cluster_raw: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_token_aware_routing!"); + return; + }; + cluster.load_balancing_config.token_awareness_enabled = enabled != 0; } @@ -1077,7 +1171,12 @@ pub unsafe extern "C" fn cass_cluster_set_token_aware_routing_shuffle_replicas( cluster_raw: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_token_aware_routing_shuffle_replicas!" + ); + return; + }; cluster .load_balancing_config @@ -1089,12 +1188,19 @@ pub unsafe extern "C" fn cass_cluster_set_retry_policy( cluster_raw: CassBorrowedExclusivePtr, retry_policy: CassBorrowedSharedPtr, ) { - let cluster = BoxFFI::as_mut_ref(cluster_raw).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_retry_policy!"); + return; + }; - let retry_policy: Arc = match ArcFFI::as_ref(retry_policy).unwrap() { - DefaultRetryPolicy(default) => Arc::clone(default) as _, - FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, - DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, + let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { + Some(DefaultRetryPolicy(default)) => Arc::clone(default) as _, + Some(FallthroughRetryPolicy(fallthrough)) => Arc::clone(fallthrough) as _, + Some(DowngradingConsistencyRetryPolicy(downgrading)) => Arc::clone(downgrading) as _, + None => { + tracing::error!("Provided null retry policy pointer to cass_cluster_set_retry_policy!"); + return; + } }; exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| { @@ -1107,8 +1213,14 @@ pub unsafe extern "C" fn cass_cluster_set_ssl( cluster: CassBorrowedExclusivePtr, ssl: CassBorrowedSharedPtr, ) { - let cluster_from_raw = BoxFFI::as_mut_ref(cluster).unwrap(); - let cass_ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); + let Some(cluster_from_raw) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_ssl!"); + return; + }; + let Some(cass_ssl) = ArcFFI::as_ref(ssl) else { + tracing::error!("Provided null ssl pointer to cass_cluster_set_ssl!"); + return; + }; let ssl_context_builder = unsafe { SslContextBuilder::from_ptr(cass_ssl.ssl_context) }; // Reference count is increased as tokio_openssl will try to free `ssl_context` when calling `SSL_free`. @@ -1122,7 +1234,11 @@ pub unsafe extern "C" fn cass_cluster_set_compression( cluster: CassBorrowedExclusivePtr, compression_type: CassCompressionType, ) { - let cluster_from_raw = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster_from_raw) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_compression!"); + return; + }; + let compression = match compression_type { CassCompressionType::CASS_COMPRESSION_LZ4 => Some(Compression::Lz4), CassCompressionType::CASS_COMPRESSION_SNAPPY => Some(Compression::Snappy), @@ -1137,7 +1253,11 @@ pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing( cluster: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_latency_aware_routing!"); + return; + }; + cluster.load_balancing_config.latency_awareness_enabled = enabled != 0; } @@ -1150,7 +1270,13 @@ pub unsafe extern "C" fn cass_cluster_set_latency_aware_routing_settings( update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let cluster = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!( + "Provided null cluster pointer to cass_cluster_set_latency_aware_routing_settings!" + ); + return; + }; + cluster.load_balancing_config.latency_awareness_builder = LatencyAwarenessBuilder::new() .exclusion_threshold(exclusion_threshold) .scale(Duration::from_millis(scale_ms)) @@ -1164,7 +1290,11 @@ pub unsafe extern "C" fn cass_cluster_set_consistency( cluster: CassBorrowedExclusivePtr, consistency: CassConsistency, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -1182,7 +1312,11 @@ pub unsafe extern "C" fn cass_cluster_set_serial_consistency( cluster: CassBorrowedExclusivePtr, serial_consistency: CassConsistency, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_serial_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let serial_consistency: SerialConsistency = match serial_consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -1211,7 +1345,11 @@ pub unsafe extern "C" fn cass_cluster_set_execution_profile_n( name_length: size_t, profile: CassBorrowedExclusivePtr, ) -> CassError { - let cluster = BoxFFI::as_mut_ref(cluster).unwrap(); + let Some(cluster) = BoxFFI::as_mut_ref(cluster) else { + tracing::error!("Provided null cluster pointer to cass_cluster_set_execution_profile_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let name = if let Some(name) = unsafe { ptr_to_cstr_n(name, name_length) }.and_then(|name| name.to_owned().try_into().ok()) { diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index f675c800..de326a3b 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -160,7 +160,11 @@ unsafe extern "C" fn cass_collection_new_from_data_type( data_type: CassBorrowedSharedPtr, item_count: size_t, ) -> CassOwnedExclusivePtr { - let data_type = ArcFFI::cloned_from_ptr(data_type).unwrap(); + let Some(data_type) = ArcFFI::cloned_from_ptr(data_type) else { + tracing::error!("Provided null data type pointer to cass_collection_new_from_data_type!"); + return BoxFFI::null_mut(); + }; + let (capacity, collection_type) = match unsafe { data_type.get_unchecked() } { CassDataTypeInner::List { .. } => { (item_count, CassCollectionType::CASS_COLLECTION_TYPE_LIST) @@ -187,7 +191,10 @@ unsafe extern "C" fn cass_collection_new_from_data_type( unsafe extern "C" fn cass_collection_data_type( collection: CassBorrowedSharedPtr, ) -> CassBorrowedSharedPtr { - let collection_ref = BoxFFI::as_ref(collection).unwrap(); + let Some(collection_ref) = BoxFFI::as_ref(collection) else { + tracing::error!("Provided null collection pointer to cass_collection_data_type!"); + return ArcFFI::null(); + }; match &collection_ref.data_type { Some(dt) => ArcFFI::as_ptr(dt), diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index fcc6c03e..2621a61a 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -204,7 +204,11 @@ pub unsafe extern "C" fn cass_statement_set_execution_profile_n( name: *const c_char, name_length: size_t, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement).unwrap(); + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_execution_profile!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let name: Option = unsafe { ptr_to_cstr_n(name, name_length) } .and_then(|name| name.to_owned().try_into().ok()); statement.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -226,7 +230,11 @@ pub unsafe extern "C" fn cass_batch_set_execution_profile_n( name: *const c_char, name_length: size_t, ) -> CassError { - let batch = BoxFFI::as_mut_ref(batch).unwrap(); + let Some(batch) = BoxFFI::as_mut_ref(batch) else { + tracing::error!("Provided null batch pointer to cass_batch_set_execution_profile!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let name: Option = unsafe { ptr_to_cstr_n(name, name_length) } .and_then(|name| name.to_owned().try_into().ok()); batch.exec_profile = name.map(PerStatementExecProfile::new_unresolved); @@ -259,7 +267,11 @@ pub unsafe extern "C" fn cass_execution_profile_set_consistency( profile: CassBorrowedExclusivePtr, consistency: CassConsistency, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!("Provided null profile pointer to cass_execution_profile_set_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let consistency: Consistency = match consistency.try_into() { Ok(c) => c, Err(_) => return CassError::CASS_ERROR_LIB_BAD_PARAMS, @@ -274,7 +286,12 @@ pub unsafe extern "C" fn cass_execution_profile_set_consistency( pub unsafe extern "C" fn cass_execution_profile_set_no_speculative_execution_policy( profile: CassBorrowedExclusivePtr, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_no_speculative_execution_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; profile_builder.modify_in_place(|builder| builder.speculative_execution_policy(None)); @@ -287,7 +304,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_constant_speculative_executi constant_delay_ms: cass_int64_t, max_speculative_executions: cass_int32_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_constant_speculative_execution_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + if constant_delay_ms < 0 || max_speculative_executions < 0 { return CassError::CASS_ERROR_LIB_BAD_PARAMS; } @@ -308,7 +331,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing( profile: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_latency_aware_routing!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + profile_builder .load_balancing_config .latency_awareness_enabled = enabled != 0; @@ -325,7 +354,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_latency_aware_routing_settin update_rate_ms: cass_uint64_t, min_measured: cass_uint64_t, ) { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_latency_aware_routing_settings!" + ); + return; + }; + profile_builder .load_balancing_config .latency_awareness_builder = LatencyAwarenessBuilder::new() @@ -361,7 +396,12 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware_n( used_hosts_per_remote_dc: cass_uint32_t, allow_remote_dcs_for_local_cl: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_load_balance_dc_aware!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; unsafe { set_load_balance_dc_aware_n( @@ -399,7 +439,12 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( local_rack_raw: *const c_char, local_rack_length: size_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_load_balance_rack_aware!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; unsafe { set_load_balance_rack_aware_n( @@ -416,7 +461,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_rack_aware_n( pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( profile: CassBorrowedExclusivePtr, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_load_balance_round_robin!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + profile_builder.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); CassError::CASS_OK @@ -427,7 +478,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_request_timeout( profile: CassBorrowedExclusivePtr, timeout_ms: cass_uint64_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_request_timeout!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + profile_builder.modify_in_place(|builder| { builder.request_timeout(Some(std::time::Duration::from_millis(timeout_ms))) }); @@ -440,12 +497,24 @@ pub unsafe extern "C" fn cass_execution_profile_set_retry_policy( profile: CassBorrowedExclusivePtr, retry_policy: CassBorrowedSharedPtr, ) -> CassError { - let retry_policy: Arc = match ArcFFI::as_ref(retry_policy).unwrap() { - DefaultRetryPolicy(default) => Arc::clone(default) as _, - FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _, - DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _, + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_retry_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; }; - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let retry_policy: Arc = match ArcFFI::as_ref(retry_policy) { + Some(DefaultRetryPolicy(default)) => Arc::clone(default) as _, + Some(FallthroughRetryPolicy(fallthrough)) => Arc::clone(fallthrough) as _, + Some(DowngradingConsistencyRetryPolicy(downgrading)) => Arc::clone(downgrading) as _, + None => { + tracing::error!( + "Provided null retry policy pointer to cass_execution_profile_set_retry_policy!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + } + }; + profile_builder.modify_in_place(|builder| builder.retry_policy(retry_policy)); CassError::CASS_OK @@ -456,7 +525,12 @@ pub unsafe extern "C" fn cass_execution_profile_set_serial_consistency( profile: CassBorrowedExclusivePtr, serial_consistency: CassConsistency, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_serial_consistency!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; let maybe_serial_consistency = if serial_consistency == CassConsistency::CASS_CONSISTENCY_UNKNOWN { @@ -477,7 +551,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing( profile: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_token_aware_routing!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + profile_builder .load_balancing_config .token_awareness_enabled = enabled != 0; @@ -490,7 +570,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_token_aware_routing_shuffle_ profile: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) -> CassError { - let profile_builder = BoxFFI::as_mut_ref(profile).unwrap(); + let Some(profile_builder) = BoxFFI::as_mut_ref(profile) else { + tracing::error!( + "Provided null profile pointer to cass_execution_profile_set_token_aware_routing_shuffle_replicas!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + profile_builder .load_balancing_config .token_aware_shuffling_replicas_enabled = enabled != 0; diff --git a/scylla-rust-wrapper/src/execution_error.rs b/scylla-rust-wrapper/src/execution_error.rs index 8351004b..39c346ea 100644 --- a/scylla-rust-wrapper/src/execution_error.rs +++ b/scylla-rust-wrapper/src/execution_error.rs @@ -68,7 +68,11 @@ pub unsafe extern "C" fn cass_error_result_free( pub unsafe extern "C" fn cass_error_result_code( error_result: CassBorrowedSharedPtr, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_code!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + error_result.to_cass_error() } @@ -76,7 +80,11 @@ pub unsafe extern "C" fn cass_error_result_code( pub unsafe extern "C" fn cass_error_result_consistency( error_result: CassBorrowedSharedPtr, ) -> CassConsistency { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_consistency!"); + return CassConsistency::CASS_CONSISTENCY_UNKNOWN; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::Unavailable { consistency, .. }, _) @@ -93,7 +101,13 @@ pub unsafe extern "C" fn cass_error_result_consistency( pub unsafe extern "C" fn cass_error_result_responses_received( error_result: CassBorrowedSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!( + "Provided null error result pointer to cass_error_result_responses_received!" + ); + return -1; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError(attempt_error)) => { match attempt_error { @@ -117,7 +131,13 @@ pub unsafe extern "C" fn cass_error_result_responses_received( pub unsafe extern "C" fn cass_error_result_responses_required( error_result: CassBorrowedSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!( + "Provided null error result pointer to cass_error_result_responses_required!" + ); + return -1; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError(attempt_error)) => { match attempt_error { @@ -141,7 +161,11 @@ pub unsafe extern "C" fn cass_error_result_responses_required( pub unsafe extern "C" fn cass_error_result_num_failures( error_result: CassBorrowedSharedPtr, ) -> cass_int32_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_num_failures!"); + return -1; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::ReadFailure { numfailures, .. }, _), @@ -157,7 +181,11 @@ pub unsafe extern "C" fn cass_error_result_num_failures( pub unsafe extern "C" fn cass_error_result_data_present( error_result: CassBorrowedSharedPtr, ) -> cass_bool_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_data_present!"); + return cass_false; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::ReadTimeout { data_present, .. }, _), @@ -185,7 +213,11 @@ pub unsafe extern "C" fn cass_error_result_data_present( pub unsafe extern "C" fn cass_error_result_write_type( error_result: CassBorrowedSharedPtr, ) -> CassWriteType { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_write_type!"); + return CassWriteType::CASS_WRITE_TYPE_UNKNOWN; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::WriteTimeout { write_type, .. }, _), @@ -203,7 +235,11 @@ pub unsafe extern "C" fn cass_error_result_keyspace( c_keyspace: *mut *const ::std::os::raw::c_char, c_keyspace_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_keyspace!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::AlreadyExists { keyspace, .. }, _), @@ -227,7 +263,11 @@ pub unsafe extern "C" fn cass_error_result_table( c_table: *mut *const ::std::os::raw::c_char, c_table_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_table!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::AlreadyExists { table, .. }, _), @@ -245,7 +285,11 @@ pub unsafe extern "C" fn cass_error_result_function( c_function: *mut *const ::std::os::raw::c_char, c_function_len: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_function!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::FunctionFailure { function, .. }, _), @@ -261,7 +305,11 @@ pub unsafe extern "C" fn cass_error_result_function( pub unsafe extern "C" fn cass_error_num_arg_types( error_result: CassBorrowedSharedPtr, ) -> size_t { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_num_arg_types!"); + return 0; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::FunctionFailure { arg_types, .. }, _), @@ -277,7 +325,11 @@ pub unsafe extern "C" fn cass_error_result_arg_type( arg_type: *mut *const ::std::os::raw::c_char, arg_type_length: *mut size_t, ) -> CassError { - let error_result: &CassErrorResult = ArcFFI::as_ref(error_result).unwrap(); + let Some(error_result) = ArcFFI::as_ref(error_result) else { + tracing::error!("Provided null error result pointer to cass_error_result_arg_type!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + match error_result { CassErrorResult::Execution(ExecutionError::LastAttemptError( RequestAttemptError::DbError(DbError::FunctionFailure { arg_types, .. }, _), diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 3ef9b1ce..0b279a3c 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -303,20 +303,22 @@ pub unsafe extern "C" fn cass_future_set_callback( callback: CassFutureCallback, data: *mut ::std::os::raw::c_void, ) -> CassError { - unsafe { - ArcFFI::as_ref(future_raw.borrow()).unwrap().set_callback( - future_raw.borrow(), - callback, - data, - ) - } + let Some(future) = ArcFFI::as_ref(future_raw.borrow()) else { + tracing::error!("Provided null future pointer to cass_future_set_callback!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + unsafe { future.set_callback(future_raw.borrow(), callback, data) } } #[unsafe(no_mangle)] pub unsafe extern "C" fn cass_future_wait(future_raw: CassBorrowedSharedPtr) { - ArcFFI::as_ref(future_raw) - .unwrap() - .with_waited_result(|_| ()); + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_wait!"); + return; + }; + + future.with_waited_result(|_| ()); } #[unsafe(no_mangle)] @@ -324,8 +326,12 @@ pub unsafe extern "C" fn cass_future_wait_timed( future_raw: CassBorrowedSharedPtr, timeout_us: cass_duration_t, ) -> cass_bool_t { - ArcFFI::as_ref(future_raw) - .unwrap() + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_wait_timed!"); + return cass_false; + }; + + future .with_waited_result_timed(|_| (), Duration::from_micros(timeout_us)) .is_ok() as cass_bool_t } @@ -334,7 +340,12 @@ pub unsafe extern "C" fn cass_future_wait_timed( pub unsafe extern "C" fn cass_future_ready( future_raw: CassBorrowedSharedPtr, ) -> cass_bool_t { - let state_guard = ArcFFI::as_ref(future_raw).unwrap().state.lock().unwrap(); + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_ready!"); + return cass_false; + }; + + let state_guard = future.state.lock().unwrap(); match state_guard.value { None => cass_false, Some(_) => cass_true, @@ -345,13 +356,16 @@ pub unsafe extern "C" fn cass_future_ready( pub unsafe extern "C" fn cass_future_error_code( future_raw: CassBorrowedSharedPtr, ) -> CassError { - ArcFFI::as_ref(future_raw) - .unwrap() - .with_waited_result(|r: &mut CassFutureResult| match r { - Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), - Err((err, _)) => *err, - _ => CassError::CASS_OK, - }) + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_error_code!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + future.with_waited_result(|r: &mut CassFutureResult| match r { + Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), + Err((err, _)) => *err, + _ => CassError::CASS_OK, + }) } #[unsafe(no_mangle)] @@ -360,19 +374,22 @@ pub unsafe extern "C" fn cass_future_error_message( message: *mut *const ::std::os::raw::c_char, message_length: *mut size_t, ) { - ArcFFI::as_ref(future) - .unwrap() - .with_waited_state(|state: &mut CassFutureState| { - let value = &state.value; - let msg = state - .err_string - .get_or_insert_with(|| match value.as_ref().unwrap() { - Ok(CassResultValue::QueryError(err)) => err.msg(), - Err((_, s)) => s.msg(), - _ => "".to_string(), - }); - unsafe { write_str_to_c(msg.as_str(), message, message_length) }; - }); + let Some(future) = ArcFFI::as_ref(future) else { + tracing::error!("Provided null future pointer to cass_future_error_message!"); + return; + }; + + future.with_waited_state(|state: &mut CassFutureState| { + let value = &state.value; + let msg = state + .err_string + .get_or_insert_with(|| match value.as_ref().unwrap() { + Ok(CassResultValue::QueryError(err)) => err.msg(), + Err((_, s)) => s.msg(), + _ => "".to_string(), + }); + unsafe { write_str_to_c(msg.as_str(), message, message_length) }; + }); } #[unsafe(no_mangle)] @@ -384,8 +401,12 @@ pub unsafe extern "C" fn cass_future_free(future_raw: CassOwnedSharedPtr, ) -> CassOwnedSharedPtr { - ArcFFI::as_ref(future_raw) - .unwrap() + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_get_result!"); + return ArcFFI::null(); + }; + + future .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryResult(qr) => Some(Arc::clone(qr)), @@ -399,8 +420,12 @@ pub unsafe extern "C" fn cass_future_get_result( pub unsafe extern "C" fn cass_future_get_error_result( future_raw: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - ArcFFI::as_ref(future_raw) - .unwrap() + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_get_error_result!"); + return ArcFFI::null(); + }; + + future .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::QueryError(qr) => Some(Arc::clone(qr)), @@ -414,8 +439,12 @@ pub unsafe extern "C" fn cass_future_get_error_result( pub unsafe extern "C" fn cass_future_get_prepared( future_raw: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - ArcFFI::as_ref(future_raw) - .unwrap() + let Some(future) = ArcFFI::as_ref(future_raw) else { + tracing::error!("Provided null future pointer to cass_future_get_prepared!"); + return ArcFFI::null(); + }; + + future .with_waited_result(|r: &mut CassFutureResult| -> Option> { match r.as_ref().ok()? { CassResultValue::Prepared(p) => Some(Arc::clone(p)), @@ -430,18 +459,21 @@ pub unsafe extern "C" fn cass_future_tracing_id( future: CassBorrowedSharedPtr, tracing_id: *mut CassUuid, ) -> CassError { - ArcFFI::as_ref(future) - .unwrap() - .with_waited_result(|r: &mut CassFutureResult| match r { - Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { - Some(id) => { - unsafe { *tracing_id = CassUuid::from(id) }; - CassError::CASS_OK - } - None => CassError::CASS_ERROR_LIB_NO_TRACING_ID, - }, - _ => CassError::CASS_ERROR_LIB_INVALID_FUTURE_TYPE, - }) + let Some(future) = ArcFFI::as_ref(future) else { + tracing::error!("Provided null future pointer to cass_future_tracing_id!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + future.with_waited_result(|r: &mut CassFutureResult| match r { + Ok(CassResultValue::QueryResult(result)) => match result.tracing_id { + Some(id) => { + unsafe { *tracing_id = CassUuid::from(id) }; + CassError::CASS_OK + } + None => CassError::CASS_ERROR_LIB_NO_TRACING_ID, + }, + _ => CassError::CASS_ERROR_LIB_INVALID_FUTURE_TYPE, + }) } #[cfg(test)] diff --git a/scylla-rust-wrapper/src/iterator.rs b/scylla-rust-wrapper/src/iterator.rs index 078039d0..e0325d7e 100644 --- a/scylla-rust-wrapper/src/iterator.rs +++ b/scylla-rust-wrapper/src/iterator.rs @@ -15,7 +15,7 @@ use crate::query_result::{ CassRawRow, CassResult, CassResultKind, CassResultMetadata, CassRow, CassValue, NonNullDeserializationError, cass_value_type, }; -use crate::types::{cass_bool_t, size_t}; +use crate::types::{cass_bool_t, cass_false, size_t}; pub use crate::cass_iterator_types::CassIteratorType; @@ -674,7 +674,11 @@ pub unsafe extern "C" fn cass_iterator_free(iterator: CassOwnedExclusivePtr, ) -> CassIteratorType { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_type!"); + // TYPE_RESULT corresponds to 0. + return CassIteratorType::CASS_ITERATOR_TYPE_RESULT; + }; match iter { CassIterator::Result(_) => CassIteratorType::CASS_ITERATOR_TYPE_RESULT, @@ -698,7 +702,10 @@ pub unsafe extern "C" fn cass_iterator_type( pub unsafe extern "C" fn cass_iterator_next( iterator: CassBorrowedExclusivePtr, ) -> cass_bool_t { - let mut iter = BoxFFI::as_mut_ref(iterator).unwrap(); + let Some(mut iter) = BoxFFI::as_mut_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_next!"); + return cass_false; + }; let result = match &mut iter { CassIterator::Result(result_iterator) => result_iterator.next(), @@ -731,7 +738,10 @@ pub unsafe extern "C" fn cass_iterator_next( pub unsafe extern "C" fn cass_iterator_get_row<'result>( iterator: CassBorrowedSharedPtr<'result, CassIterator<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassRow<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_row!"); + return RefFFI::null(); + }; // Defined only for result iterator, for other types should return null let CassIterator::Result(CassResultIterator::Rows(rows_result_iterator)) = iter else { @@ -749,7 +759,10 @@ pub unsafe extern "C" fn cass_iterator_get_row<'result>( pub unsafe extern "C" fn cass_iterator_get_column<'result>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_column!"); + return RefFFI::null(); + }; // Defined only for row iterator, for other types should return null if let CassIterator::Row(row_iterator) = iter { @@ -773,7 +786,10 @@ pub unsafe extern "C" fn cass_iterator_get_column<'result>( pub unsafe extern "C" fn cass_iterator_get_value<'result>( iterator: CassBorrowedSharedPtr<'result, CassIterator<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_value!"); + return RefFFI::null(); + }; // Defined only for collections(list, set and map) or tuple iterator, for other types should return null match iter { @@ -814,7 +830,10 @@ pub unsafe extern "C" fn cass_iterator_get_value<'result>( pub unsafe extern "C" fn cass_iterator_get_map_key<'result>( iterator: CassBorrowedSharedPtr<'result, CassIterator<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_map_key!"); + return RefFFI::null(); + }; let CassIterator::Map(map_iterator) = iter else { return RefFFI::null(); @@ -831,7 +850,10 @@ pub unsafe extern "C" fn cass_iterator_get_map_key<'result>( pub unsafe extern "C" fn cass_iterator_get_map_value<'result>( iterator: CassBorrowedSharedPtr<'result, CassIterator<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_map_value!"); + return RefFFI::null(); + }; let CassIterator::Map(map_iterator) = iter else { return RefFFI::null(); @@ -850,7 +872,12 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!( + "Provided null iterator pointer to cass_iterator_get_user_type_field_name!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; let CassIterator::Udt(udt_iterator) = iter else { return CassError::CASS_ERROR_LIB_BAD_PARAMS; @@ -872,7 +899,12 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_name( pub unsafe extern "C" fn cass_iterator_get_user_type_field_value<'result>( iterator: CassBorrowedSharedPtr<'result, CassIterator<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!( + "Provided null iterator pointer to cass_iterator_get_user_type_field_value!" + ); + return RefFFI::null(); + }; let CassIterator::Udt(udt_iterator) = iter else { return RefFFI::null(); @@ -889,7 +921,10 @@ pub unsafe extern "C" fn cass_iterator_get_user_type_field_value<'result>( pub unsafe extern "C" fn cass_iterator_get_keyspace_meta<'schema>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'schema, CassKeyspaceMeta, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_keyspace_meta!"); + return RefFFI::null(); + }; if let CassIterator::KeyspacesMeta(schema_meta_iterator) = iter { let iter_position = match schema_meta_iterator.position { @@ -916,7 +951,10 @@ pub unsafe extern "C" fn cass_iterator_get_keyspace_meta<'schema>( pub unsafe extern "C" fn cass_iterator_get_table_meta<'schema>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'schema, CassTableMeta, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_table_meta!"); + return RefFFI::null(); + }; if let CassIterator::TablesMeta(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { @@ -943,7 +981,10 @@ pub unsafe extern "C" fn cass_iterator_get_table_meta<'schema>( pub unsafe extern "C" fn cass_iterator_get_user_type<'schema>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'schema, CassDataType, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_user_type!"); + return ArcFFI::null(); + }; if let CassIterator::UserTypes(keyspace_meta_iterator) = iter { let iter_position = match keyspace_meta_iterator.position { @@ -970,7 +1011,10 @@ pub unsafe extern "C" fn cass_iterator_get_user_type<'schema>( pub unsafe extern "C" fn cass_iterator_get_column_meta<'schema>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'schema, CassColumnMeta, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!("Provided null iterator pointer to cass_iterator_get_column_meta!"); + return RefFFI::null(); + }; match iter { CassIterator::ColumnsMeta(CassColumnsMetaIterator::FromTable(table_meta_iterator)) => { @@ -1016,7 +1060,12 @@ pub unsafe extern "C" fn cass_iterator_get_column_meta<'schema>( pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta<'schema>( iterator: CassBorrowedSharedPtr, CConst>, ) -> CassBorrowedSharedPtr<'schema, CassMaterializedViewMeta, CConst> { - let iter = BoxFFI::as_ref(iterator).unwrap(); + let Some(iter) = BoxFFI::as_ref(iterator) else { + tracing::error!( + "Provided null iterator pointer to cass_iterator_get_materialized_view_meta!" + ); + return RefFFI::null(); + }; match iter { CassIterator::MaterializedViewsMeta(CassMaterializedViewsMetaIterator::FromKeyspace( @@ -1058,7 +1107,10 @@ pub unsafe extern "C" fn cass_iterator_get_materialized_view_meta<'schema>( pub unsafe extern "C" fn cass_iterator_from_result<'result>( result: CassBorrowedSharedPtr<'result, CassResult, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let result_from_raw = ArcFFI::as_ref(result).unwrap(); + let Some(result_from_raw) = ArcFFI::as_ref(result) else { + tracing::error!("Provided null result pointer to cass_iterator_from_result!"); + return BoxFFI::null_mut(); + }; let iterator = match &result_from_raw.kind { CassResultKind::NonRows => CassResultIterator::NonRows, @@ -1084,7 +1136,10 @@ pub unsafe extern "C" fn cass_iterator_from_result<'result>( pub unsafe extern "C" fn cass_iterator_from_row<'result>( row: CassBorrowedSharedPtr<'result, CassRow<'result>, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let row_from_raw = RefFFI::as_ref(row).unwrap(); + let Some(row_from_raw) = RefFFI::as_ref(row) else { + tracing::error!("Provided null row pointer to cass_iterator_from_row!"); + return BoxFFI::null_mut(); + }; let iterator = CassRowIterator { row: row_from_raw, @@ -1134,7 +1189,10 @@ pub unsafe extern "C" fn cass_iterator_from_collection<'result>( pub unsafe extern "C" fn cass_iterator_from_tuple<'result>( value: CassBorrowedSharedPtr<'result, CassValue<'result>, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let tuple = RefFFI::as_ref(value).unwrap(); + let Some(tuple) = RefFFI::as_ref(value) else { + tracing::error!("Provided null tuple pointer to cass_iterator_from_tuple!"); + return BoxFFI::null_mut(); + }; let iterator_result = CassTupleIterator::new_from_value(tuple); match iterator_result { @@ -1151,7 +1209,10 @@ pub unsafe extern "C" fn cass_iterator_from_tuple<'result>( pub unsafe extern "C" fn cass_iterator_from_map<'result>( value: CassBorrowedSharedPtr<'result, CassValue<'result>, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let map = RefFFI::as_ref(value).unwrap(); + let Some(map) = RefFFI::as_ref(value) else { + tracing::error!("Provided null map pointer to cass_iterator_from_map!"); + return BoxFFI::null_mut(); + }; let iterator_result = CassMapIterator::new_from_value(map); @@ -1169,7 +1230,10 @@ pub unsafe extern "C" fn cass_iterator_from_map<'result>( pub unsafe extern "C" fn cass_iterator_fields_from_user_type<'result>( value: CassBorrowedSharedPtr<'result, CassValue<'result>, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let udt = RefFFI::as_ref(value).unwrap(); + let Some(udt) = RefFFI::as_ref(value) else { + tracing::error!("Provided null UDT pointer to cass_iterator_fields_from_user_type!"); + return BoxFFI::null_mut(); + }; let iterator_result = CassUdtIterator::new_from_value(udt); match iterator_result { @@ -1186,7 +1250,12 @@ pub unsafe extern "C" fn cass_iterator_fields_from_user_type<'result>( pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta<'schema>( schema_meta: CassBorrowedSharedPtr<'schema, CassSchemaMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = BoxFFI::as_ref(schema_meta).unwrap(); + let Some(metadata) = BoxFFI::as_ref(schema_meta) else { + tracing::error!( + "Provided null schema metadata pointer to cass_iterator_keyspaces_from_schema_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassSchemaMetaIterator { value: metadata, @@ -1202,7 +1271,12 @@ pub unsafe extern "C" fn cass_iterator_keyspaces_from_schema_meta<'schema>( pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta<'schema>( keyspace_meta: CassBorrowedSharedPtr<'schema, CassKeyspaceMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(keyspace_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_iterator_tables_from_keyspace_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1218,7 +1292,12 @@ pub unsafe extern "C" fn cass_iterator_tables_from_keyspace_meta<'schema>( pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta<'schema>( keyspace_meta: CassBorrowedSharedPtr<'schema, CassKeyspaceMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(keyspace_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_iterator_materialized_views_from_keyspace_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1236,7 +1315,12 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_keyspace_meta<'sc pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta<'schema>( keyspace_meta: CassBorrowedSharedPtr<'schema, CassKeyspaceMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(keyspace_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_iterator_user_types_from_keyspace_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassKeyspaceMetaIterator { value: metadata, @@ -1252,7 +1336,12 @@ pub unsafe extern "C" fn cass_iterator_user_types_from_keyspace_meta<'schema>( pub unsafe extern "C" fn cass_iterator_columns_from_table_meta<'schema>( table_meta: CassBorrowedSharedPtr<'schema, CassTableMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(table_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_iterator_columns_from_table_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassTableMetaIterator { value: metadata, @@ -1270,7 +1359,12 @@ pub unsafe extern "C" fn cass_iterator_columns_from_table_meta<'schema>( pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta<'schema>( table_meta: CassBorrowedSharedPtr<'schema, CassTableMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(table_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_iterator_materialized_views_from_table_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassTableMetaIterator { value: metadata, @@ -1288,7 +1382,12 @@ pub unsafe extern "C" fn cass_iterator_materialized_views_from_table_meta<'schem pub unsafe extern "C" fn cass_iterator_columns_from_materialized_view_meta<'schema>( view_meta: CassBorrowedSharedPtr<'schema, CassMaterializedViewMeta, CConst>, ) -> CassOwnedExclusivePtr, CMut> { - let metadata = RefFFI::as_ref(view_meta).unwrap(); + let Some(metadata) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null view metadata pointer to cass_iterator_columns_from_materialized_view_meta!" + ); + return BoxFFI::null_mut(); + }; let iterator = CassViewMetaIterator { value: metadata, diff --git a/scylla-rust-wrapper/src/metadata.rs b/scylla-rust-wrapper/src/metadata.rs index 3c94d146..576ec85c 100644 --- a/scylla-rust-wrapper/src/metadata.rs +++ b/scylla-rust-wrapper/src/metadata.rs @@ -139,11 +139,16 @@ pub unsafe extern "C" fn cass_schema_meta_keyspace_by_name_n( keyspace_name: *const c_char, keyspace_name_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(metadata) = BoxFFI::as_ref(schema_meta) else { + tracing::error!( + "Provided null schema metadata pointer to cass_schema_meta_keyspace_by_name_n!" + ); + return RefFFI::null(); + }; if keyspace_name.is_null() { return RefFFI::null(); } - let metadata = BoxFFI::as_ref(schema_meta).unwrap(); let keyspace = unsafe { ptr_to_cstr_n(keyspace_name, keyspace_name_length) }.unwrap(); let keyspace_meta = metadata.keyspaces.get(keyspace); @@ -160,7 +165,11 @@ pub unsafe extern "C" fn cass_keyspace_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let keyspace_meta = RefFFI::as_ref(keyspace_meta).unwrap(); + let Some(keyspace_meta) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!("Provided null keyspace metadata pointer to cass_keyspace_meta_name!"); + return; + }; + unsafe { write_str_to_c(keyspace_meta.name.as_str(), name, name_length) } } @@ -178,11 +187,16 @@ pub unsafe extern "C" fn cass_keyspace_meta_user_type_by_name_n( type_: *const c_char, type_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(keyspace_meta) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_keyspace_meta_user_type_by_name_n!" + ); + return ArcFFI::null(); + }; if type_.is_null() { return ArcFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta).unwrap(); let user_type_name = unsafe { ptr_to_cstr_n(type_, type_length) }.unwrap(); match keyspace_meta @@ -208,11 +222,16 @@ pub unsafe extern "C" fn cass_keyspace_meta_table_by_name_n( table: *const c_char, table_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(keyspace_meta) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_keyspace_meta_table_by_name_n!" + ); + return RefFFI::null(); + }; if table.is_null() { return RefFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta).unwrap(); let table_name = unsafe { ptr_to_cstr_n(table, table_length) }.unwrap(); let table_meta = keyspace_meta.tables.get(table_name); @@ -229,7 +248,11 @@ pub unsafe extern "C" fn cass_table_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!("Provided null table metadata pointer to cass_table_meta_name!"); + return; + }; + unsafe { write_str_to_c(table_meta.name.as_str(), name, name_length) } } @@ -237,7 +260,11 @@ pub unsafe extern "C" fn cass_table_meta_name( pub unsafe extern "C" fn cass_table_meta_column_count( table_meta: CassBorrowedSharedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!("Provided null table metadata pointer to cass_table_meta_column_count!"); + return 0; + }; + table_meta.columns_metadata.len() as size_t } @@ -264,7 +291,10 @@ pub unsafe extern "C" fn cass_table_meta_column( // Then cks by position: h, i // Then remaining columns alphabetically: b, c, f, g - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!("Provided null table metadata pointer to cass_table_meta_column!"); + return RefFFI::null(); + }; let index = index as usize; // Check if the index lands in partition keys. If so, simply return the corresponding column. @@ -299,7 +329,10 @@ pub unsafe extern "C" fn cass_table_meta_partition_key( table_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!("Provided null table metadata pointer to cass_table_meta_partition_key!"); + return RefFFI::null(); + }; match table_meta.partition_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { @@ -314,7 +347,13 @@ pub unsafe extern "C" fn cass_table_meta_partition_key( pub unsafe extern "C" fn cass_table_meta_partition_key_count( table_meta: CassBorrowedSharedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_partition_key_count!" + ); + return 0; + }; + table_meta.partition_keys.len() as size_t } @@ -323,7 +362,10 @@ pub unsafe extern "C" fn cass_table_meta_clustering_key( table_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!("Provided null table metadata pointer to cass_table_meta_clustering_key!"); + return RefFFI::null(); + }; match table_meta.clustering_keys.get(index as usize) { Some(column_name) => match table_meta.columns_metadata.get(column_name) { @@ -338,7 +380,13 @@ pub unsafe extern "C" fn cass_table_meta_clustering_key( pub unsafe extern "C" fn cass_table_meta_clustering_key_count( table_meta: CassBorrowedSharedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_clustering_key_count!" + ); + return 0; + }; + table_meta.clustering_keys.len() as size_t } @@ -356,11 +404,16 @@ pub unsafe extern "C" fn cass_table_meta_column_by_name_n( column: *const c_char, column_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_column_by_name_n!" + ); + return RefFFI::null(); + }; if column.is_null() { return RefFFI::null(); } - let table_meta = RefFFI::as_ref(table_meta).unwrap(); let column_name = unsafe { ptr_to_cstr_n(column, column_length) }.unwrap(); match table_meta.columns_metadata.get(column_name) { @@ -375,7 +428,11 @@ pub unsafe extern "C" fn cass_column_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let column_meta = RefFFI::as_ref(column_meta).unwrap(); + let Some(column_meta) = RefFFI::as_ref(column_meta) else { + tracing::error!("Provided null column metadata pointer to cass_column_meta_name!"); + return; + }; + unsafe { write_str_to_c(column_meta.name.as_str(), name, name_length) } } @@ -383,7 +440,11 @@ pub unsafe extern "C" fn cass_column_meta_name( pub unsafe extern "C" fn cass_column_meta_data_type( column_meta: CassBorrowedSharedPtr, ) -> CassBorrowedSharedPtr { - let column_meta = RefFFI::as_ref(column_meta).unwrap(); + let Some(column_meta) = RefFFI::as_ref(column_meta) else { + tracing::error!("Provided null column metadata pointer to cass_column_meta_data_type!"); + return ArcFFI::null(); + }; + ArcFFI::as_ptr(&column_meta.column_type) } @@ -391,7 +452,11 @@ pub unsafe extern "C" fn cass_column_meta_data_type( pub unsafe extern "C" fn cass_column_meta_type( column_meta: CassBorrowedSharedPtr, ) -> CassColumnType { - let column_meta = RefFFI::as_ref(column_meta).unwrap(); + let Some(column_meta) = RefFFI::as_ref(column_meta) else { + tracing::error!("Provided null column metadata pointer to cass_column_meta_type!"); + return CassColumnType::CASS_COLUMN_TYPE_REGULAR; + }; + column_meta.column_kind } @@ -409,11 +474,16 @@ pub unsafe extern "C" fn cass_keyspace_meta_materialized_view_by_name_n( view: *const c_char, view_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(keyspace_meta) = RefFFI::as_ref(keyspace_meta) else { + tracing::error!( + "Provided null keyspace metadata pointer to cass_keyspace_meta_materialized_view_by_name_n!" + ); + return RefFFI::null(); + }; if view.is_null() { return RefFFI::null(); } - let keyspace_meta = RefFFI::as_ref(keyspace_meta).unwrap(); let view_name = unsafe { ptr_to_cstr_n(view, view_length).unwrap() }; match keyspace_meta.views.get(view_name) { @@ -436,11 +506,16 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name_n( view: *const c_char, view_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_materialized_view_by_name_n!" + ); + return RefFFI::null(); + }; if view.is_null() { return RefFFI::null(); } - let table_meta = RefFFI::as_ref(table_meta).unwrap(); let view_name = unsafe { ptr_to_cstr_n(view, view_length).unwrap() }; match table_meta.views.get(view_name) { @@ -453,7 +528,13 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view_by_name_n( pub unsafe extern "C" fn cass_table_meta_materialized_view_count( table_meta: CassBorrowedSharedPtr, ) -> size_t { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_materialized_view_count!" + ); + return 0; + }; + table_meta.views.len() as size_t } @@ -462,7 +543,12 @@ pub unsafe extern "C" fn cass_table_meta_materialized_view( table_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let table_meta = RefFFI::as_ref(table_meta).unwrap(); + let Some(table_meta) = RefFFI::as_ref(table_meta) else { + tracing::error!( + "Provided null table metadata pointer to cass_table_meta_materialized_view!" + ); + return RefFFI::null(); + }; match table_meta.views.iter().nth(index as usize) { Some(view_meta) => RefFFI::as_ptr(view_meta.1.as_ref()), @@ -484,11 +570,17 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column_by_name_n( column: *const c_char, column_length: size_t, ) -> CassBorrowedSharedPtr { + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_column_by_name_n!" + ); + return RefFFI::null(); + }; + if column.is_null() { return RefFFI::null(); } - let view_meta = RefFFI::as_ref(view_meta).unwrap(); let column_name = unsafe { ptr_to_cstr_n(column, column_length).unwrap() }; match view_meta.view_metadata.columns_metadata.get(column_name) { @@ -503,7 +595,13 @@ pub unsafe extern "C" fn cass_materialized_view_meta_name( name: *mut *const c_char, name_length: *mut size_t, ) { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_name!" + ); + return; + }; + unsafe { write_str_to_c(view_meta.name.as_str(), name, name_length) } } @@ -511,7 +609,12 @@ pub unsafe extern "C" fn cass_materialized_view_meta_name( pub unsafe extern "C" fn cass_materialized_view_meta_base_table( view_meta: CassBorrowedSharedPtr, ) -> CassBorrowedSharedPtr { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_base_table!" + ); + return RefFFI::null(); + }; let ptr = unsafe { RefFFI::weak_as_ptr(&view_meta.base_table) }; if RefFFI::is_null(&ptr) { @@ -535,7 +638,12 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column( view_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_column!" + ); + return RefFFI::null(); + }; match view_meta .view_metadata @@ -552,7 +660,13 @@ pub unsafe extern "C" fn cass_materialized_view_meta_column( pub unsafe extern "C" fn cass_materialized_view_meta_partition_key_count( view_meta: CassBorrowedSharedPtr, ) -> size_t { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_partition_key_count!" + ); + return 0; + }; + view_meta.view_metadata.partition_keys.len() as size_t } @@ -560,7 +674,12 @@ pub unsafe extern "C" fn cass_materialized_view_meta_partition_key( view_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_partition_key!" + ); + return RefFFI::null(); + }; match view_meta.view_metadata.partition_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { @@ -575,7 +694,13 @@ pub unsafe extern "C" fn cass_materialized_view_meta_partition_key( pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key_count( view_meta: CassBorrowedSharedPtr, ) -> size_t { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_clustering_key_count!" + ); + return 0; + }; + view_meta.view_metadata.clustering_keys.len() as size_t } @@ -583,7 +708,12 @@ pub unsafe extern "C" fn cass_materialized_view_meta_clustering_key( view_meta: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let view_meta = RefFFI::as_ref(view_meta).unwrap(); + let Some(view_meta) = RefFFI::as_ref(view_meta) else { + tracing::error!( + "Provided null materialized view metadata pointer to cass_materialized_view_meta_clustering_key!" + ); + return RefFFI::null(); + }; match view_meta.view_metadata.clustering_keys.get(index as usize) { Some(column_name) => match view_meta.view_metadata.columns_metadata.get(column_name) { diff --git a/scylla-rust-wrapper/src/prepared.rs b/scylla-rust-wrapper/src/prepared.rs index f29bca5f..23a57d91 100644 --- a/scylla-rust-wrapper/src/prepared.rs +++ b/scylla-rust-wrapper/src/prepared.rs @@ -88,7 +88,11 @@ pub unsafe extern "C" fn cass_prepared_free( pub unsafe extern "C" fn cass_prepared_bind( prepared_raw: CassBorrowedSharedPtr, ) -> CassOwnedExclusivePtr { - let prepared: Arc<_> = ArcFFI::cloned_from_ptr(prepared_raw).unwrap(); + let Some(prepared) = ArcFFI::cloned_from_ptr(prepared_raw) else { + tracing::error!("Provided null prepared statement pointer to cass_prepared_bind!"); + return BoxFFI::null_mut(); + }; + let bound_values_size = prepared.statement.get_variable_col_specs().len(); // cloning prepared statement's arc, because creating CassStatement should not invalidate @@ -116,7 +120,12 @@ pub unsafe extern "C" fn cass_prepared_parameter_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let prepared = ArcFFI::as_ref(prepared_raw).unwrap(); + let Some(prepared) = ArcFFI::as_ref(prepared_raw) else { + tracing::error!( + "Provided null prepared statement pointer to cass_prepared_parameter_name!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; match prepared .statement @@ -136,7 +145,12 @@ pub unsafe extern "C" fn cass_prepared_parameter_data_type( prepared_raw: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let prepared = ArcFFI::as_ref(prepared_raw).unwrap(); + let Some(prepared) = ArcFFI::as_ref(prepared_raw) else { + tracing::error!( + "Provided null prepared statement pointer to cass_prepared_parameter_data_type!" + ); + return ArcFFI::null(); + }; match prepared.variable_col_data_types.get(index as usize) { Some(dt) => ArcFFI::as_ptr(dt), @@ -158,7 +172,13 @@ pub unsafe extern "C" fn cass_prepared_parameter_data_type_by_name_n( name: *const c_char, name_length: size_t, ) -> CassBorrowedSharedPtr { - let prepared = ArcFFI::as_ref(prepared_raw).unwrap(); + let Some(prepared) = ArcFFI::as_ref(prepared_raw) else { + tracing::error!( + "Provided null prepared statement pointer to cass_prepared_parameter_data_type_by_name!" + ); + return ArcFFI::null(); + }; + let parameter_name = unsafe { ptr_to_cstr_n(name, name_length).expect("Prepared parameter name is not UTF-8") }; diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index a76e3c81..a88b86b1 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -482,11 +482,11 @@ pub unsafe extern "C" fn cass_result_free(result_raw: CassOwnedSharedPtr, ) -> cass_bool_t { - unsafe { result_has_more_pages(&result) } -} + let Some(result) = ArcFFI::as_ref(result.borrow()) else { + tracing::error!("Provided null result pointer to cass_result_has_more_pages!"); + return cass_false; + }; -unsafe fn result_has_more_pages(result: &CassBorrowedSharedPtr) -> cass_bool_t { - let result = ArcFFI::as_ref(result.borrow()).unwrap(); (!result.paging_state_response.finished()) as cass_bool_t } @@ -495,7 +495,10 @@ pub unsafe extern "C" fn cass_row_get_column<'result>( row_raw: CassBorrowedSharedPtr<'result, CassRow<'result>, CConst>, index: size_t, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let row: &CassRow = RefFFI::as_ref(row_raw).unwrap(); + let Some(row) = RefFFI::as_ref(row_raw) else { + tracing::error!("Provided null row pointer to cass_row_get_column!"); + return RefFFI::null(); + }; let index_usize: usize = index.try_into().unwrap(); let column_value = match row.columns.get(index_usize) { @@ -523,7 +526,11 @@ pub unsafe extern "C" fn cass_row_get_column_by_name_n<'result>( name: *const c_char, name_length: size_t, ) -> CassBorrowedSharedPtr<'result, CassValue<'result>, CConst> { - let row_from_raw = RefFFI::as_ref(row).unwrap(); + let Some(row_from_raw) = RefFFI::as_ref(row) else { + tracing::error!("Provided null row pointer to cass_row_get_column_by_name_n!"); + return RefFFI::null(); + }; + let mut name_str = unsafe { ptr_to_cstr_n(name, name_length).unwrap() }; let mut is_case_sensitive = false; @@ -556,7 +563,11 @@ pub unsafe extern "C" fn cass_result_column_name( name: *mut *const c_char, name_length: *mut size_t, ) -> CassError { - let result_from_raw = ArcFFI::as_ref(result).unwrap(); + let Some(result_from_raw) = ArcFFI::as_ref(result) else { + tracing::error!("Provided null result pointer to cass_result_column_name!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let index_usize: usize = index.try_into().unwrap(); let CassResultKind::Rows(CassRowsResult { shared_data, .. }) = &result_from_raw.kind else { @@ -596,7 +607,11 @@ pub unsafe extern "C" fn cass_result_column_data_type( result: CassBorrowedSharedPtr, index: size_t, ) -> CassBorrowedSharedPtr { - let result_from_raw: &CassResult = ArcFFI::as_ref(result).unwrap(); + let Some(result_from_raw) = ArcFFI::as_ref(result) else { + tracing::error!("Provided null result pointer to cass_result_column_data_type!"); + return ArcFFI::null(); + }; + let index_usize: usize = index .try_into() .expect("Provided index is out of bounds. Max possible value is usize::MAX"); @@ -617,7 +632,11 @@ pub unsafe extern "C" fn cass_result_column_data_type( pub unsafe extern "C" fn cass_value_type( value: CassBorrowedSharedPtr, ) -> CassValueType { - let value_from_raw = RefFFI::as_ref(value).unwrap(); + let Some(value_from_raw) = RefFFI::as_ref(value) else { + tracing::error!("Provided null value pointer to cass_value_type!"); + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + }; + unsafe { cass_data_type_type(ArcFFI::as_ptr(value_from_raw.value_type)) } } @@ -625,17 +644,23 @@ pub unsafe extern "C" fn cass_value_type( pub unsafe extern "C" fn cass_value_data_type<'result>( value: CassBorrowedSharedPtr<'result, CassValue<'result>, CConst>, ) -> CassBorrowedSharedPtr<'result, CassDataType, CConst> { - let value_from_raw = RefFFI::as_ref(value).unwrap(); + let Some(value_from_raw) = RefFFI::as_ref(value) else { + tracing::error!("Provided null value pointer to cass_value_data_type!"); + return ArcFFI::null(); + }; ArcFFI::as_ptr(value_from_raw.value_type) } macro_rules! val_ptr_to_ref_ensure_non_null { - ($ptr:ident) => {{ + ($ptr:ident, $fn_name:expr) => {{ let maybe_ref = RefFFI::as_ref($ptr); match maybe_ref { Some(r) => r, - None => return CassError::CASS_ERROR_LIB_NULL_VALUE, + None => { + tracing::error!("Provided null value pointer to {}!", $fn_name); + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } } }}; } @@ -645,7 +670,7 @@ pub unsafe extern "C" fn cass_value_get_float( value: CassBorrowedSharedPtr, output: *mut cass_float_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_float"); let f: f32 = match val.get_non_null() { Ok(v) => v, @@ -661,7 +686,7 @@ pub unsafe extern "C" fn cass_value_get_double( value: CassBorrowedSharedPtr, output: *mut cass_double_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_double"); let f: f64 = match val.get_non_null() { Ok(v) => v, @@ -677,7 +702,7 @@ pub unsafe extern "C" fn cass_value_get_bool( value: CassBorrowedSharedPtr, output: *mut cass_bool_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_bool"); let b: bool = match val.get_non_null() { Ok(v) => v, @@ -693,7 +718,7 @@ pub unsafe extern "C" fn cass_value_get_int8( value: CassBorrowedSharedPtr, output: *mut cass_int8_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_int8"); let i: i8 = match val.get_non_null() { Ok(v) => v, @@ -709,7 +734,7 @@ pub unsafe extern "C" fn cass_value_get_int16( value: CassBorrowedSharedPtr, output: *mut cass_int16_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_int16"); let i: i16 = match val.get_non_null() { Ok(v) => v, @@ -725,7 +750,7 @@ pub unsafe extern "C" fn cass_value_get_uint32( value: CassBorrowedSharedPtr, output: *mut cass_uint32_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_uint32"); let date: CqlDate = match val.get_non_null() { Ok(v) => v, @@ -741,7 +766,7 @@ pub unsafe extern "C" fn cass_value_get_int32( value: CassBorrowedSharedPtr, output: *mut cass_int32_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_int32"); let i: i32 = match val.get_non_null() { Ok(v) => v, @@ -757,7 +782,7 @@ pub unsafe extern "C" fn cass_value_get_int64( value: CassBorrowedSharedPtr, output: *mut cass_int64_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_int64"); let i: i64 = match val.value.typ() { ColumnType::Native(NativeType::BigInt) => match val.get_non_null::() { @@ -801,7 +826,7 @@ pub unsafe extern "C" fn cass_value_get_uuid( value: CassBorrowedSharedPtr, output: *mut CassUuid, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_uuid"); let uuid: Uuid = match val.value.typ() { ColumnType::Native(NativeType::Uuid) => match val.get_non_null::() { @@ -831,7 +856,7 @@ pub unsafe extern "C" fn cass_value_get_inet( value: CassBorrowedSharedPtr, output: *mut CassInet, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_inet"); let inet: IpAddr = match val.get_non_null() { Ok(v) => v, @@ -849,7 +874,7 @@ pub unsafe extern "C" fn cass_value_get_decimal( varint_size: *mut size_t, scale: *mut cass_int32_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_decimal"); let decimal: CqlDecimalBorrowed = match val.get_non_null() { Ok(v) => v, @@ -872,7 +897,7 @@ pub unsafe extern "C" fn cass_value_get_string( output: *mut *const c_char, output_size: *mut size_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_string"); // It seems that cpp driver doesn't check the type - you can call _get_string // on any type and get internal representation. I don't see how to do it easily in @@ -903,7 +928,7 @@ pub unsafe extern "C" fn cass_value_get_duration( days: *mut cass_int32_t, nanos: *mut cass_int64_t, ) -> CassError { - let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let val: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_duration"); let duration: CqlDuration = match val.get_non_null() { Ok(v) => v, @@ -925,7 +950,7 @@ pub unsafe extern "C" fn cass_value_get_bytes( output: *mut *const cass_byte_t, output_size: *mut size_t, ) -> CassError { - let value_from_raw: &CassValue = val_ptr_to_ref_ensure_non_null!(value); + let value_from_raw: &CassValue = val_ptr_to_ref_ensure_non_null!(value, "cass_value_get_bytes"); let bytes = match value_from_raw.get_bytes_non_null() { Ok(s) => s, @@ -944,7 +969,11 @@ pub unsafe extern "C" fn cass_value_get_bytes( pub unsafe extern "C" fn cass_value_is_null( value: CassBorrowedSharedPtr, ) -> cass_bool_t { - let val: &CassValue = RefFFI::as_ref(value).unwrap(); + let Some(val) = RefFFI::as_ref(value) else { + tracing::error!("Provided null value pointer to cass_value_is_null!"); + return cass_false; + }; + val.value.slice().is_none() as cass_bool_t } @@ -952,7 +981,10 @@ pub unsafe extern "C" fn cass_value_is_null( pub unsafe extern "C" fn cass_value_is_collection( value: CassBorrowedSharedPtr, ) -> cass_bool_t { - let val = RefFFI::as_ref(value).unwrap(); + let Some(val) = RefFFI::as_ref(value) else { + tracing::error!("Provided null value pointer to cass_value_is_collection!"); + return cass_false; + }; matches!( unsafe { val.value_type.get_unchecked() }.get_value_type(), @@ -966,7 +998,10 @@ pub unsafe extern "C" fn cass_value_is_collection( pub unsafe extern "C" fn cass_value_is_duration( value: CassBorrowedSharedPtr, ) -> cass_bool_t { - let val = RefFFI::as_ref(value).unwrap(); + let Some(val) = RefFFI::as_ref(value) else { + tracing::error!("Provided null value pointer to cass_value_is_duration!"); + return cass_false; + }; (unsafe { val.value_type.get_unchecked() }.get_value_type() == CassValueType::CASS_VALUE_TYPE_DURATION) as cass_bool_t @@ -976,7 +1011,10 @@ pub unsafe extern "C" fn cass_value_is_duration( pub unsafe extern "C" fn cass_value_item_count( collection: CassBorrowedSharedPtr, ) -> size_t { - let val = RefFFI::as_ref(collection).unwrap(); + let Some(val) = RefFFI::as_ref(collection) else { + tracing::error!("Provided null value pointer to cass_value_item_count!"); + return 0; + }; val.value.item_count().unwrap_or(0) as size_t } @@ -985,7 +1023,10 @@ pub unsafe extern "C" fn cass_value_item_count( pub unsafe extern "C" fn cass_value_primary_sub_type( collection: CassBorrowedSharedPtr, ) -> CassValueType { - let val = RefFFI::as_ref(collection).unwrap(); + let Some(val) = RefFFI::as_ref(collection) else { + tracing::error!("Provided null value pointer to cass_value_primary_sub_type!"); + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + }; match unsafe { val.value_type.get_unchecked() } { CassDataTypeInner::List { @@ -1006,7 +1047,10 @@ pub unsafe extern "C" fn cass_value_primary_sub_type( pub unsafe extern "C" fn cass_value_secondary_sub_type( collection: CassBorrowedSharedPtr, ) -> CassValueType { - let val = RefFFI::as_ref(collection).unwrap(); + let Some(val) = RefFFI::as_ref(collection) else { + tracing::error!("Provided null value pointer to cass_value_secondary_sub_type!"); + return CassValueType::CASS_VALUE_TYPE_UNKNOWN; + }; match unsafe { val.value_type.get_unchecked() } { CassDataTypeInner::Map { @@ -1021,7 +1065,10 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( pub unsafe extern "C" fn cass_result_row_count( result_raw: CassBorrowedSharedPtr, ) -> size_t { - let result = ArcFFI::as_ref(result_raw).unwrap(); + let Some(result) = ArcFFI::as_ref(result_raw) else { + tracing::error!("Provided null result pointer to cass_result_row_count!"); + return 0; + }; let CassResultKind::Rows(CassRowsResult { shared_data, .. }) = &result.kind else { return 0; @@ -1034,7 +1081,10 @@ pub unsafe extern "C" fn cass_result_row_count( pub unsafe extern "C" fn cass_result_column_count( result_raw: CassBorrowedSharedPtr, ) -> size_t { - let result = ArcFFI::as_ref(result_raw).unwrap(); + let Some(result) = ArcFFI::as_ref(result_raw) else { + tracing::error!("Provided null result pointer to cass_result_column_count!"); + return 0; + }; let CassResultKind::Rows(CassRowsResult { shared_data, .. }) = &result.kind else { return 0; @@ -1047,7 +1097,10 @@ pub unsafe extern "C" fn cass_result_column_count( pub unsafe extern "C" fn cass_result_first_row( result_raw: CassBorrowedSharedPtr, ) -> CassBorrowedSharedPtr { - let result = ArcFFI::as_ref(result_raw).unwrap(); + let Some(result) = ArcFFI::as_ref(result_raw) else { + tracing::error!("Provided null result pointer to cass_result_first_row!"); + return RefFFI::null(); + }; let CassResultKind::Rows(CassRowsResult { first_row, .. }) = &result.kind else { return RefFFI::null(); @@ -1066,12 +1119,15 @@ pub unsafe extern "C" fn cass_result_paging_state_token( paging_state: *mut *const c_char, paging_state_size: *mut size_t, ) -> CassError { - if unsafe { result_has_more_pages(&result) } == cass_false { + let Some(result_from_raw) = ArcFFI::as_ref(result.borrow()) else { + tracing::error!("Provided null result pointer to cass_result_paging_state_token!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + if unsafe { cass_result_has_more_pages(result.borrow()) } == cass_false { return CassError::CASS_ERROR_LIB_NO_PAGING_STATE; } - let result_from_raw = ArcFFI::as_ref(result).unwrap(); - match &result_from_raw.paging_state_response { PagingStateResponse::HasMorePages { state } => match state.as_bytes_slice() { Some(result_paging_state) => unsafe { diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 91a75cdc..2612c9e2 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -154,8 +154,14 @@ pub unsafe extern "C" fn cass_session_connect( session_raw: CassBorrowedSharedPtr, cluster_raw: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - let session_opt = ArcFFI::cloned_from_ptr(session_raw).unwrap(); - let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw).unwrap(); + let Some(session_opt) = ArcFFI::cloned_from_ptr(session_raw) else { + tracing::error!("Provided null session pointer to cass_session_connect!"); + return ArcFFI::null(); + }; + let Some(cluster) = BoxFFI::as_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_session_connect!"); + return ArcFFI::null(); + }; CassSessionInner::connect(session_opt, cluster, None) } @@ -176,8 +182,14 @@ pub unsafe extern "C" fn cass_session_connect_keyspace_n( keyspace: *const c_char, keyspace_length: size_t, ) -> CassOwnedSharedPtr { - let session_opt = ArcFFI::cloned_from_ptr(session_raw).unwrap(); - let cluster: &CassCluster = BoxFFI::as_ref(cluster_raw).unwrap(); + let Some(session_opt) = ArcFFI::cloned_from_ptr(session_raw) else { + tracing::error!("Provided null session pointer to cass_session_connect_keyspace_n!"); + return ArcFFI::null(); + }; + let Some(cluster) = BoxFFI::as_ref(cluster_raw) else { + tracing::error!("Provided null cluster pointer to cass_session_connect_keyspace_n!"); + return ArcFFI::null(); + }; let keyspace = unsafe { ptr_to_cstr_n(keyspace, keyspace_length) }.map(ToOwned::to_owned); CassSessionInner::connect(session_opt, cluster, keyspace) @@ -188,8 +200,15 @@ pub unsafe extern "C" fn cass_session_execute_batch( session_raw: CassBorrowedSharedPtr, batch_raw: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - let session_opt = ArcFFI::cloned_from_ptr(session_raw).unwrap(); - let batch_from_raw = BoxFFI::as_ref(batch_raw).unwrap(); + let Some(session_opt) = ArcFFI::cloned_from_ptr(session_raw) else { + tracing::error!("Provided null session pointer to cass_session_execute_batch!"); + return ArcFFI::null(); + }; + let Some(batch_from_raw) = BoxFFI::as_ref(batch_raw) else { + tracing::error!("Provided null batch pointer to cass_session_execute_batch!"); + return ArcFFI::null(); + }; + let mut state = batch_from_raw.state.clone(); let request_timeout_ms = batch_from_raw.batch_request_timeout_ms; @@ -254,10 +273,17 @@ pub unsafe extern "C" fn cass_session_execute( session_raw: CassBorrowedSharedPtr, statement_raw: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - let session_opt = ArcFFI::cloned_from_ptr(session_raw).unwrap(); + let Some(session_opt) = ArcFFI::cloned_from_ptr(session_raw) else { + tracing::error!("Provided null session pointer to cass_session_execute!"); + return ArcFFI::null(); + }; // DO NOT refer to `statement_opt` inside the async block, as I've done just to face a segfault. - let statement_opt = BoxFFI::as_ref(statement_raw).unwrap(); + let Some(statement_opt) = BoxFFI::as_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_session_execute!"); + return ArcFFI::null(); + }; + let paging_state = statement_opt.paging_state.clone(); let paging_enabled = statement_opt.paging_enabled; let request_timeout_ms = statement_opt.request_timeout_ms; @@ -389,8 +415,15 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( cass_session: CassBorrowedSharedPtr, statement: CassBorrowedSharedPtr, ) -> CassOwnedSharedPtr { - let session = ArcFFI::cloned_from_ptr(cass_session).unwrap(); - let cass_statement = BoxFFI::as_ref(statement).unwrap(); + let Some(session) = ArcFFI::cloned_from_ptr(cass_session) else { + tracing::error!("Provided null session pointer to cass_session_prepare_from_existing!"); + return ArcFFI::null(); + }; + let Some(cass_statement) = BoxFFI::as_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_session_prepare_from_existing!"); + return ArcFFI::null(); + }; + let statement = cass_statement.statement.clone(); CassFuture::make_raw(async move { @@ -434,6 +467,11 @@ pub unsafe extern "C" fn cass_session_prepare_n( query: *const c_char, query_length: size_t, ) -> CassOwnedSharedPtr { + let Some(cass_session) = ArcFFI::cloned_from_ptr(cass_session_raw) else { + tracing::error!("Provided null session pointer to cass_session_prepare_n!"); + return ArcFFI::null(); + }; + let query_str = unsafe { ptr_to_cstr_n(query, query_length) } // Apparently nullptr denotes an empty statement string. // It seems to be intended (for some weird reason, why not save a round-trip???) @@ -441,7 +479,6 @@ pub unsafe extern "C" fn cass_session_prepare_n( // There is a test for this: `NullStringApiArgsTest.Integration_Cassandra_PrepareNullQuery`. .unwrap_or_default(); let query = Statement::new(query_str.to_string()); - let cass_session = ArcFFI::cloned_from_ptr(cass_session_raw).unwrap(); CassFuture::make_raw(async move { let session_guard = cass_session.read().await; @@ -476,7 +513,10 @@ pub unsafe extern "C" fn cass_session_free(session_raw: CassOwnedSharedPtr, ) -> CassOwnedSharedPtr { - let session_opt = ArcFFI::cloned_from_ptr(session).unwrap(); + let Some(session_opt) = ArcFFI::cloned_from_ptr(session) else { + tracing::error!("Provided null session pointer to cass_session_close!"); + return ArcFFI::null(); + }; CassFuture::make_raw(async move { let mut session_guard = session_opt.write().await; @@ -497,7 +537,10 @@ pub unsafe extern "C" fn cass_session_close( pub unsafe extern "C" fn cass_session_get_client_id( session: CassBorrowedSharedPtr, ) -> CassUuid { - let cass_session = ArcFFI::as_ref(session).unwrap(); + let Some(cass_session) = ArcFFI::as_ref(session) else { + tracing::error!("Provided null session pointer to cass_session_get_client_id!"); + return uuid::Uuid::nil().into(); + }; let client_id: uuid::Uuid = cass_session.blocking_read().as_ref().unwrap().client_id; client_id.into() diff --git a/scylla-rust-wrapper/src/ssl.rs b/scylla-rust-wrapper/src/ssl.rs index 03dd12ce..b138d08e 100644 --- a/scylla-rust-wrapper/src/ssl.rs +++ b/scylla-rust-wrapper/src/ssl.rs @@ -116,7 +116,11 @@ pub unsafe extern "C" fn cass_ssl_add_trusted_cert_n( cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); + let Some(ssl) = ArcFFI::cloned_from_ptr(ssl) else { + tracing::error!("Provided null ssl pointer to cass_ssl_add_trusted_cert_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let bio = unsafe { BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()) }; if bio.is_null() { @@ -151,7 +155,10 @@ pub unsafe extern "C" fn cass_ssl_set_verify_flags( ssl: CassBorrowedSharedPtr, flags: i32, ) { - let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); + let Some(ssl) = ArcFFI::cloned_from_ptr(ssl) else { + tracing::error!("Provided null ssl pointer to cass_ssl_set_verify_flags!"); + return; + }; match flags { CASS_SSL_VERIFY_NONE => unsafe { @@ -196,7 +203,11 @@ pub unsafe extern "C" fn cass_ssl_set_cert_n( cert: *const c_char, cert_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); + let Some(ssl) = ArcFFI::cloned_from_ptr(ssl) else { + tracing::error!("Provided null ssl pointer to cass_ssl_set_cert_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let bio = unsafe { BIO_new_mem_buf(cert as *const c_void, cert_length.try_into().unwrap()) }; if bio.is_null() { @@ -295,7 +306,11 @@ pub unsafe extern "C" fn cass_ssl_set_private_key_n( password: *mut c_char, _password_length: size_t, ) -> CassError { - let ssl = ArcFFI::cloned_from_ptr(ssl).unwrap(); + let Some(ssl) = ArcFFI::cloned_from_ptr(ssl) else { + tracing::error!("Provided null ssl pointer to cass_ssl_set_private_key_n!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let bio = unsafe { BIO_new_mem_buf(key as *const c_void, key_length.try_into().unwrap()) }; if bio.is_null() { diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index c767e0c4..5aec5029 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -309,10 +309,15 @@ pub unsafe extern "C" fn cass_statement_set_consistency( statement: CassBorrowedExclusivePtr, consistency: CassConsistency, ) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_consistency!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let consistency_opt = get_consistency_from_cass_consistency(consistency); if let Some(consistency) = consistency_opt { - match &mut BoxFFI::as_mut_ref(statement).unwrap().statement { + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_consistency(consistency), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -328,7 +333,11 @@ pub unsafe extern "C" fn cass_statement_set_paging_size( statement_raw: CassBorrowedExclusivePtr, page_size: c_int, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement_raw).unwrap(); + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_paging_size!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + if page_size <= 0 { // Cpp driver sets the page size flag only for positive page size provided by user. statement.paging_enabled = false; @@ -350,8 +359,14 @@ pub unsafe extern "C" fn cass_statement_set_paging_state( statement: CassBorrowedExclusivePtr, result: CassBorrowedSharedPtr, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement).unwrap(); - let result = ArcFFI::as_ref(result).unwrap(); + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_paging_state!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let Some(result) = ArcFFI::as_ref(result) else { + tracing::error!("Provided null result pointer to cass_statement_set_paging_state!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; match &result.paging_state_response { PagingStateResponse::HasMorePages { state } => statement.paging_state.clone_from(state), @@ -366,7 +381,12 @@ pub unsafe extern "C" fn cass_statement_set_paging_state_token( paging_state: *const c_char, paging_state_size: size_t, ) -> CassError { - let statement_from_raw = BoxFFI::as_mut_ref(statement).unwrap(); + let Some(statement_from_raw) = BoxFFI::as_mut_ref(statement) else { + tracing::error!( + "Provided null statement pointer to cass_statement_set_paging_state_token!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; if paging_state.is_null() { statement_from_raw.paging_state = PagingState::start(); @@ -385,7 +405,12 @@ pub unsafe extern "C" fn cass_statement_set_is_idempotent( statement_raw: CassBorrowedExclusivePtr, is_idempotent: cass_bool_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement_raw).unwrap().statement { + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_is_idempotent!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_is_idempotent(is_idempotent != 0), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -400,7 +425,12 @@ pub unsafe extern "C" fn cass_statement_set_tracing( statement_raw: CassBorrowedExclusivePtr, enabled: cass_bool_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement_raw).unwrap().statement { + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_set_tracing!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_tracing(enabled != 0), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -415,6 +445,11 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( statement: CassBorrowedExclusivePtr, retry_policy: CassBorrowedSharedPtr, ) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_retry_policy!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + let maybe_arced_retry_policy: Option> = ArcFFI::as_ref(retry_policy).map(|policy| match policy { CassRetryPolicy::DefaultRetryPolicy(default) => { @@ -424,7 +459,7 @@ pub unsafe extern "C" fn cass_statement_set_retry_policy( CassRetryPolicy::DowngradingConsistencyRetryPolicy(downgrading) => downgrading.clone(), }); - match &mut BoxFFI::as_mut_ref(statement).unwrap().statement { + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_retry_policy(maybe_arced_retry_policy), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -439,6 +474,13 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( statement: CassBorrowedExclusivePtr, serial_consistency: CassConsistency, ) -> CassError { + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!( + "Provided null statement pointer to cass_statement_set_serial_consistency!" + ); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + // cpp-driver doesn't validate passed value in any way. // If it is an incorrect serial-consistency value then it will be set // and sent as-is. @@ -453,7 +495,7 @@ pub unsafe extern "C" fn cass_statement_set_serial_consistency( _ => return CassError::CASS_ERROR_LIB_BAD_PARAMS, }; - match &mut BoxFFI::as_mut_ref(statement).unwrap().statement { + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_serial_consistency(Some(consistency)), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -485,7 +527,12 @@ pub unsafe extern "C" fn cass_statement_set_timestamp( statement: CassBorrowedExclusivePtr, timestamp: cass_int64_t, ) -> CassError { - match &mut BoxFFI::as_mut_ref(statement).unwrap().statement { + let Some(statement) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_timestamp!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + + match &mut statement.statement { BoundStatement::Simple(inner) => inner.query.set_timestamp(Some(timestamp)), BoundStatement::Prepared(inner) => Arc::make_mut(&mut inner.statement) .statement @@ -500,6 +547,11 @@ pub unsafe extern "C" fn cass_statement_set_request_timeout( statement: CassBorrowedExclusivePtr, timeout_ms: cass_uint64_t, ) -> CassError { + let Some(statement_from_raw) = BoxFFI::as_mut_ref(statement) else { + tracing::error!("Provided null statement pointer to cass_statement_set_request_timeout!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + // The maximum duration for a sleep is 68719476734 milliseconds (approximately 2.2 years). // Note: this is limited by tokio::time:timout // https://github.com/tokio-rs/tokio/blob/4b1c4801b1383800932141d0f6508d5b3003323e/tokio/src/time/driver/wheel/mod.rs#L44-L50 @@ -508,7 +560,6 @@ pub unsafe extern "C" fn cass_statement_set_request_timeout( return CassError::CASS_ERROR_LIB_BAD_PARAMS; } - let statement_from_raw = BoxFFI::as_mut_ref(statement).unwrap(); statement_from_raw.request_timeout_ms = Some(timeout_ms); CassError::CASS_OK @@ -519,7 +570,11 @@ pub unsafe extern "C" fn cass_statement_reset_parameters( statement_raw: CassBorrowedExclusivePtr, count: size_t, ) -> CassError { - let statement = BoxFFI::as_mut_ref(statement_raw).unwrap(); + let Some(statement) = BoxFFI::as_mut_ref(statement_raw) else { + tracing::error!("Provided null statement pointer to cass_statement_reset_parameters!"); + return CassError::CASS_ERROR_LIB_BAD_PARAMS; + }; + statement.reset_bound_values(count as usize); CassError::CASS_OK diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 76554251..81f831c5 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -78,7 +78,11 @@ pub unsafe extern "C" fn cass_tuple_new( unsafe extern "C" fn cass_tuple_new_from_data_type( data_type: CassBorrowedSharedPtr, ) -> CassOwnedExclusivePtr { - let data_type = ArcFFI::cloned_from_ptr(data_type).unwrap(); + let Some(data_type) = ArcFFI::cloned_from_ptr(data_type) else { + tracing::error!("Provided null data type pointer to cass_tuple_new_from_data_type!"); + return BoxFFI::null_mut(); + }; + let item_count = match unsafe { data_type.get_unchecked() } { CassDataTypeInner::Tuple(v) => v.len(), _ => return BoxFFI::null_mut(), @@ -98,7 +102,12 @@ unsafe extern "C" fn cass_tuple_free(tuple: CassOwnedExclusivePtr, ) -> CassBorrowedSharedPtr { - match &BoxFFI::as_ref(tuple).unwrap().data_type { + let Some(tuple) = BoxFFI::as_ref(tuple) else { + tracing::error!("Provided null tuple pointer to cass_tuple_data_type!"); + return ArcFFI::null(); + }; + + match &tuple.data_type { Some(t) => ArcFFI::as_ptr(t), None => ArcFFI::as_ptr(&UNTYPED_TUPLE_TYPE), } diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index d4d93a89..a232e903 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -87,7 +87,10 @@ impl From<&CassUserType> for CassCqlValue { pub unsafe extern "C" fn cass_user_type_new_from_data_type( data_type_raw: CassBorrowedSharedPtr, ) -> CassOwnedExclusivePtr { - let data_type = ArcFFI::cloned_from_ptr(data_type_raw).unwrap(); + let Some(data_type) = ArcFFI::cloned_from_ptr(data_type_raw) else { + tracing::error!("Provided null data type pointer to cass_user_type_new_from_data_type!"); + return BoxFFI::null_mut(); + }; match unsafe { data_type.get_unchecked() } { CassDataTypeInner::UDT(udt_data_type) => { @@ -109,7 +112,12 @@ pub unsafe extern "C" fn cass_user_type_free(user_type: CassOwnedExclusivePtr, ) -> CassBorrowedSharedPtr { - ArcFFI::as_ptr(&BoxFFI::as_ref(user_type).unwrap().data_type) + let Some(user_type) = BoxFFI::as_ref(user_type) else { + tracing::error!("Provided null user type pointer to cass_user_type_data_type!"); + return ArcFFI::null(); + }; + + ArcFFI::as_ptr(&user_type.data_type) } prepare_binders_macro!(@index_and_name CassUserType, diff --git a/scylla-rust-wrapper/src/uuid.rs b/scylla-rust-wrapper/src/uuid.rs index c30ca34d..e5d686bc 100644 --- a/scylla-rust-wrapper/src/uuid.rs +++ b/scylla-rust-wrapper/src/uuid.rs @@ -138,7 +138,10 @@ pub unsafe extern "C" fn cass_uuid_gen_time( uuid_gen: CassBorrowedExclusivePtr, output: *mut CassUuid, ) { - let uuid_gen = BoxFFI::as_mut_ref(uuid_gen).unwrap(); + let Some(uuid_gen) = BoxFFI::as_mut_ref(uuid_gen) else { + tracing::error!("Provided null uuid generator pointer to cass_uuid_gen_time!"); + return; + }; let uuid = CassUuid { time_and_version: set_version(monotonic_timestamp(&mut uuid_gen.last_timestamp), 1), @@ -168,7 +171,10 @@ pub unsafe extern "C" fn cass_uuid_gen_from_time( timestamp: cass_uint64_t, output: *mut CassUuid, ) { - let uuid_gen = BoxFFI::as_mut_ref(uuid_gen).unwrap(); + let Some(uuid_gen) = BoxFFI::as_mut_ref(uuid_gen) else { + tracing::error!("Provided null uuid generator pointer to cass_uuid_gen_from_time!"); + return; + }; let uuid = CassUuid { time_and_version: set_version(from_unix_timestamp(timestamp), 1),