From 87806d9ab70367f0cbd9f9c1fa4f3a259341b8b9 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Tue, 24 Sep 2024 14:26:37 +0800 Subject: [PATCH 01/23] feat(expressions-compatibility) add FFI interface to get operators along with fields The FFI interface introducing by this commit has corresponding Go wrapping: ``` func ValidateExpression(atc string, s *Schema) (bool, []string, int64) { atcC := unsafe.Pointer(C.CString(atc)) defer C.free(atcC) errLen := C.ulong(1024) errBuf := [1024]C.uchar{} expr := C.expression_validate((*C.uchar)(atcC), s.s, &errBuf[0], &errLen) defer C.expression_validate_free_result(expr) if expr == nil { fmt.Println("Error: ", string(errBuf[:errLen])) return false, nil, 0 } validate := bool(expr.validate) operators := int64(expr.operators) flds := make([]string, expr.fields_total) flds_slice := unsafe.Slice(expr.fields, expr.fields_total) for i := range flds { flds[i] = C.GoString((*C.char)(unsafe.Pointer(flds_slice[i]))) } return validate, flds, operators } ``` --- Cargo.lock | 7 ++ Cargo.toml | 4 +- cbindgen.toml | 7 ++ src/ast.rs | 40 ++++++++++ src/ffi.rs | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 272 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index fb34bf1f..dfffa828 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,6 +27,7 @@ checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" name = "atc-router" version = "1.6.1" dependencies = [ + "bitflags", "cidr", "criterion", "fnv", @@ -45,6 +46,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "block-buffer" version = "0.10.4" diff --git a/Cargo.toml b/Cargo.toml index 5782964c..5aa16258 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ regex = "1" serde = { version = "1.0", features = ["derive"], optional = true } serde_regex = { version = "1.1", optional = true } fnv = "1" +bitflags = "2.6.0" [dev-dependencies] criterion = "0.*" @@ -29,7 +30,7 @@ criterion = "0.*" crate-type = ["lib", "cdylib", "staticlib"] [features] -default = ["ffi"] +default = ["ffi", "expr_validation"] ffi = [] serde = ["cidr/serde", "dep:serde", "dep:serde_regex"] @@ -52,3 +53,4 @@ harness = false [[bench]] name = "build" harness = false +expr_validation = [] diff --git a/cbindgen.toml b/cbindgen.toml index a2cbb7e1..bfc55038 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -6,3 +6,10 @@ prefix_with_name = true [defines] "feature = ffi" = "DEFINE_ATC_ROUTER_FFI" +"feature = expr_validation" = "DEFINE_ATC_ROUTER_EXPR_VALIDATION" + +[macro_expansion] +bitflags = true + +[export] +include = ["BinaryOperatorFlags"] \ No newline at end of file diff --git a/src/ast.rs b/src/ast.rs index 42a04752..37257e13 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -45,6 +45,46 @@ pub enum BinaryOperator { Contains, // contains } +#[cfg(feature = "expr_validation")] +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[repr(C)] + pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { + const EQUALS = 1 << 0; + const NOT_EQUALS = 1 << 1; + const REGEX = 1 << 2; + const PREFIX = 1 << 3; + const POSTFIX = 1 << 4; + const GREATER = 1 << 5; + const GREATER_OR_EQUAL = 1 << 6; + const LESS = 1 << 7; + const LESS_OR_EQUAL = 1 << 8; + const IN = 1 << 9; + const NOT_IN = 1 << 10; + const CONTAINS = 1 << 11; + } +} + +#[cfg(feature = "expr_validation")] +impl From<&BinaryOperator> for BinaryOperatorFlags { + fn from(op: &BinaryOperator) -> Self { + match op { + BinaryOperator::Equals => Self::EQUALS, + BinaryOperator::NotEquals => Self::NOT_EQUALS, + BinaryOperator::Regex => Self::REGEX, + BinaryOperator::Prefix => Self::PREFIX, + BinaryOperator::Postfix => Self::POSTFIX, + BinaryOperator::Greater => Self::GREATER, + BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, + BinaryOperator::Less => Self::LESS, + BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, + BinaryOperator::In => Self::IN, + BinaryOperator::NotIn => Self::NOT_IN, + BinaryOperator::Contains => Self::CONTAINS, + } + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub enum Value { diff --git a/src/ffi.rs b/src/ffi.rs index 2a4aa8c2..b86c3deb 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -578,6 +578,164 @@ pub unsafe extern "C" fn context_get_result( .unwrap() } +#[cfg(feature = "expr_validation")] +#[derive(Debug)] +#[repr(C)] +pub struct ExpressionValidationResult { + validate: bool, // if validate is false, then none of the following fields are valid + fields: *mut *mut c_char, + fields_total: usize, + operators: u64, +} + +/// Validate the ATC expression with the schema. +/// +/// # Arguments +/// +/// - `atc`: the C-style string representing the ATC expression. +/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. +/// - `errbuf`: a buffer to store the error message. +/// - `errbuf_len`: a pointer to the length of the error message buffer. +/// +/// # Returns +/// +/// Returns a pointer of `ExpressionValidationResult`. +/// If the expression is not valid, the `validate` field will be `false`, +/// and the error message will be stored in the `errbuf`, +/// and the length of the error message will be stored in `errbuf_len`. +/// +/// # Panics +/// +/// This function will panic when: +/// +/// - `atc` doesn't point to a valid C-style string. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, +/// and it must be properly aligned. +/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, +/// and it must be properly aligned. +#[cfg(feature = "expr_validation")] +#[no_mangle] +pub unsafe extern "C" fn expression_validate( + atc: *const u8, + schema: &Schema, + errbuf: *mut u8, + errbuf_len: *mut usize, +) -> *mut ExpressionValidationResult { + use std::collections::HashMap; + + use crate::ast::{BinaryOperatorFlags, Expression, LogicalExpression}; + use crate::parser::parse; + use crate::semantics::FieldCounter; + use crate::semantics::Validate; + + let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + let mut validation_result = Box::new(ExpressionValidationResult { + validate: false, + fields: std::ptr::null_mut(), + fields_total: 0, + operators: 0, + }); + + // Parse the expression + let result = parse(atc).map_err(|e| e.to_string()); + if let Err(e) = result { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + validation_result.validate = false; + return Box::into_raw(validation_result); + } + // Unwrap is safe since we've already checked for error + let ast = result.unwrap(); + + // Validate expression with schema + if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + validation_result.validate = false; + return Box::into_raw(validation_result); + } + + // Get used fields + let mut expr_fields = HashMap::new(); + ast.add_to_counter(&mut expr_fields); + let fields_count = expr_fields.len(); + let mut fields = Vec::<*const c_char>::with_capacity(fields_count); + + for k in expr_fields.into_keys() { + let ffi_string = ffi::CString::new(k).unwrap(); + let ptr = ffi_string.into_raw(); // Leak the CString + fields.push(ptr); + } + + // Get used operators + let mut ops = BinaryOperatorFlags::empty(); + fn visit(expr: &Expression, ops: &mut BinaryOperatorFlags) { + match expr { + Expression::Logical(logic_expression) => match logic_expression.as_ref() { + LogicalExpression::And(lhs, rhs) => { + visit(lhs, ops); + visit(rhs, ops); + } + LogicalExpression::Or(lhs, rhs) => { + visit(lhs, ops); + visit(rhs, ops); + } + LogicalExpression::Not(rhs) => { + visit(rhs, ops); + } + }, + Expression::Predicate(predict) => { + let op = BinaryOperatorFlags::from(&predict.op); + ops.insert(op); + } + } + } + visit(&ast, &mut ops); + + validation_result.validate = true; + validation_result.operators = ops.bits(); + let boxed_fields = fields.into_boxed_slice(); + let raw_boxed_fields = Box::into_raw(boxed_fields); // Leak the Box + + validation_result.fields = raw_boxed_fields.cast(); + validation_result.fields_total = raw_boxed_fields.len(); + + Box::into_raw(validation_result) // Leak the Box +} + +/// Deallocate the ExpressionValidationResult object. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `result` must be a valid pointer returned by [`expression_validate`]. +#[cfg(feature = "expr_validation")] +#[no_mangle] +pub unsafe extern "C" fn expression_validate_free_result(result: *mut ExpressionValidationResult) { + let result = Box::from_raw(result); + let slice = std::slice::from_raw_parts_mut(result.fields, result.fields_total); + let boxed_fields = Box::from_raw(slice); + for ptr in boxed_fields.into_vec() { + let _ = ffi::CString::from_raw(ptr); // Drop the leaked CString + } // Drop the Box + drop(result); // Drop the Box +} + #[cfg(test)] mod tests { use super::*; @@ -627,4 +785,61 @@ mod tests { assert!(errbuf_len < ERR_BUF_MAX_LEN); } } + + #[cfg(feature = "expr_validation")] + #[test] + fn test_expression_validate() { + use crate::ast::BinaryOperatorFlags; + unsafe { + let mut schema = Schema::default(); + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + let atc = ffi::CString::new(atc).unwrap(); + let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; + let mut errbuf_len = ERR_BUF_MAX_LEN; + + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expression_validate( + atc.as_bytes().as_ptr(), + &schema, + errbuf.as_mut_ptr(), + &mut errbuf_len, + ); + + assert!((*result).validate, "Validation failed"); + assert_eq!((*result).fields_total, 4, "Fields count mismatch"); + assert_eq!( + (*result).operators, + (BinaryOperatorFlags::EQUALS + | BinaryOperatorFlags::REGEX + | BinaryOperatorFlags::IN + | BinaryOperatorFlags::NOT_IN + | BinaryOperatorFlags::CONTAINS) + .bits(), + "Operators mismatch" + ); + let mut fields = Vec::::with_capacity((*result).fields_total); + for i in 0..(*result).fields_total { + let field = (*result).fields.add(i); + let field = ffi::CStr::from_ptr(*field).to_str().unwrap(); + fields.push(field.to_string()); + } + fields.sort(); + assert_eq!( + fields, + vec![ + "http.path".to_string(), + "net.dst.port".to_string(), + "net.protocol".to_string(), + "net.src.ip".to_string() + ], + "Fields mismatch" + ); + + expression_validate_free_result(result); + } + } } From 7a3931648080b564b95a22ee5fb3b6ca84ee4e17 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Tue, 24 Sep 2024 17:06:58 +0800 Subject: [PATCH 02/23] feat(expressions-compatibility) refactor to not allocate memory from Rust An example go binding for the FFI interface introduced by this commit: ```Go func ValidateExpression(atc string, s *Schema) (bool, []string, uint64, error) { atcC := unsafe.Pointer(C.CString(atc)) defer C.free(atcC) errLen := C.ulong(1024) errBuf := [1024]C.uchar{} fieldsLen := C.ulong(1024) fieldsBuf := [1024]C.uchar{} fieldsTotal := C.ulong(0) operatorsC := C.uint64_t(0) result := C.expression_validate((*C.uchar)(atcC), s.s, &fieldsBuf[0], &fieldsLen, &fieldsTotal, &operatorsC, &errBuf[0], &errLen) if bool(result) == false { return false, nil, 0, fmt.Errorf(string(errBuf[:errLen])) } operators := uint64(operatorsC) flds := make([]string, uintptr(fieldsTotal)) p := 0 for i := range flds { flds[i] = C.GoString((*C.char)(unsafe.Pointer(&fieldsBuf[p]))) p += len(flds[i]) + 1 } return true, flds, operators, nil } ``` --- src/ffi.rs | 128 ++++++++++++++++++++++++++--------------------------- 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index b86c3deb..02e82fca 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -578,16 +578,6 @@ pub unsafe extern "C" fn context_get_result( .unwrap() } -#[cfg(feature = "expr_validation")] -#[derive(Debug)] -#[repr(C)] -pub struct ExpressionValidationResult { - validate: bool, // if validate is false, then none of the following fields are valid - fields: *mut *mut c_char, - fields_total: usize, - operators: u64, -} - /// Validate the ATC expression with the schema. /// /// # Arguments @@ -616,6 +606,13 @@ pub struct ExpressionValidationResult { /// /// - `atc` must be a valid pointer to a C-style string, must be properly aligned, /// and must not have '\0' in the middle. +/// - `schema` must be a valid pointer returned by [`schema_new`]. +/// - `fields_buf` must be a valid to write for `fields_len * size_of::()` bytes, +/// and it must be properly aligned. +/// - `fields_len` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned. +/// - `fields_total` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned. /// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, /// and it must be properly aligned. /// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, @@ -625,9 +622,13 @@ pub struct ExpressionValidationResult { pub unsafe extern "C" fn expression_validate( atc: *const u8, schema: &Schema, + fields_buf: *mut u8, + fields_len: *mut usize, + fields_total: *mut usize, + operators: *mut u64, errbuf: *mut u8, errbuf_len: *mut usize, -) -> *mut ExpressionValidationResult { +) -> bool { use std::collections::HashMap; use crate::ast::{BinaryOperatorFlags, Expression, LogicalExpression}; @@ -637,12 +638,6 @@ pub unsafe extern "C" fn expression_validate( let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); - let mut validation_result = Box::new(ExpressionValidationResult { - validate: false, - fields: std::ptr::null_mut(), - fields_total: 0, - operators: 0, - }); // Parse the expression let result = parse(atc).map_err(|e| e.to_string()); @@ -650,8 +645,7 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - validation_result.validate = false; - return Box::into_raw(validation_result); + return false; } // Unwrap is safe since we've already checked for error let ast = result.unwrap(); @@ -661,20 +655,30 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - validation_result.validate = false; - return Box::into_raw(validation_result); + return false; } // Get used fields let mut expr_fields = HashMap::new(); ast.add_to_counter(&mut expr_fields); let fields_count = expr_fields.len(); - let mut fields = Vec::<*const c_char>::with_capacity(fields_count); + let mut fields = Vec::with_capacity(fields_count); + let mut total_fields_length = 0; for k in expr_fields.into_keys() { - let ffi_string = ffi::CString::new(k).unwrap(); - let ptr = ffi_string.into_raw(); // Leak the CString - fields.push(ptr); + total_fields_length += k.as_bytes().len() + 1; // +1 for trailing \0 + fields.push(k); + } + + if *fields_len < total_fields_length { + let err_msg = format!( + "Fields buffer too small, provided {} bytes but required at least {} bytes.", + *fields_len, total_fields_length + ); + let errlen = min(err_msg.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&err_msg.as_bytes()[..errlen]); + *errbuf_len = errlen; + return false; } // Get used operators @@ -702,38 +706,21 @@ pub unsafe extern "C" fn expression_validate( } visit(&ast, &mut ops); - validation_result.validate = true; - validation_result.operators = ops.bits(); - let boxed_fields = fields.into_boxed_slice(); - let raw_boxed_fields = Box::into_raw(boxed_fields); // Leak the Box - - validation_result.fields = raw_boxed_fields.cast(); - validation_result.fields_total = raw_boxed_fields.len(); - - Box::into_raw(validation_result) // Leak the Box -} + // fulfill the output parameters, + let mut fields_buf_ptr = fields_buf; + for field in fields { + let field = ffi::CString::new(field).unwrap(); + let field_slice = field.as_bytes_with_nul(); + let field_len = field_slice.len(); + let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); + fields_buf.copy_from_slice(field_slice); + fields_buf_ptr = fields_buf_ptr.add(field_len); + } + *fields_total = fields_count; + *fields_len = total_fields_length; + *operators = ops.bits(); -/// Deallocate the ExpressionValidationResult object. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `result` must be a valid pointer returned by [`expression_validate`]. -#[cfg(feature = "expr_validation")] -#[no_mangle] -pub unsafe extern "C" fn expression_validate_free_result(result: *mut ExpressionValidationResult) { - let result = Box::from_raw(result); - let slice = std::slice::from_raw_parts_mut(result.fields, result.fields_total); - let boxed_fields = Box::from_raw(slice); - for ptr in boxed_fields.into_vec() { - let _ = ffi::CString::from_raw(ptr); // Drop the leaked CString - } // Drop the Box - drop(result); // Drop the Box + true } #[cfg(test)] @@ -802,17 +789,26 @@ mod tests { schema.add_field("net.src.ip", Type::IpAddr); schema.add_field("http.path", Type::String); + let mut fields_buf = vec![0u8; 1024]; + let mut fields_len = fields_buf.len(); + let mut fields_total = 0; + let mut operators = 0u64; + let result = expression_validate( atc.as_bytes().as_ptr(), &schema, + fields_buf.as_mut_ptr(), + &mut fields_len, + &mut fields_total, + &mut operators, errbuf.as_mut_ptr(), &mut errbuf_len, ); - assert!((*result).validate, "Validation failed"); - assert_eq!((*result).fields_total, 4, "Fields count mismatch"); + assert!(result, "Validation failed"); + assert_eq!(fields_total, 4, "Fields count mismatch"); assert_eq!( - (*result).operators, + operators, (BinaryOperatorFlags::EQUALS | BinaryOperatorFlags::REGEX | BinaryOperatorFlags::IN @@ -821,11 +817,13 @@ mod tests { .bits(), "Operators mismatch" ); - let mut fields = Vec::::with_capacity((*result).fields_total); - for i in 0..(*result).fields_total { - let field = (*result).fields.add(i); - let field = ffi::CStr::from_ptr(*field).to_str().unwrap(); - fields.push(field.to_string()); + let mut fields = Vec::::with_capacity(fields_total); + let mut p = 0; + for _ in 0..fields_total { + let field = ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()); + let len = field.to_bytes().len() + 1; + fields.push(field.to_string_lossy().to_string()); + p += len; } fields.sort(); assert_eq!( @@ -838,8 +836,6 @@ mod tests { ], "Fields mismatch" ); - - expression_validate_free_result(result); } } } From 31d89d357bbaee1c07e36c6cb118c4aed72b4181 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 25 Sep 2024 10:24:07 +0800 Subject: [PATCH 03/23] feat(expressions-compatibility) more flexible interface --- src/ffi.rs | 157 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 98 insertions(+), 59 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 02e82fca..f5757efc 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -584,21 +584,35 @@ pub unsafe extern "C" fn context_get_result( /// /// - `atc`: the C-style string representing the ATC expression. /// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. +/// - `fields_buf`: a buffer to store the used fields. +/// - `fields_len`: a pointer to the length of the fields buffer. +/// - `fields_total`: a pointer for saving the total number of the fields. +/// - `operators`: a pointer for saving the used operators with bitflags. /// - `errbuf`: a buffer to store the error message. /// - `errbuf_len`: a pointer to the length of the error message buffer. /// /// # Returns /// -/// Returns a pointer of `ExpressionValidationResult`. -/// If the expression is not valid, the `validate` field will be `false`, -/// and the error message will be stored in the `errbuf`, -/// and the length of the error message will be stored in `errbuf_len`. +/// Returns the boolean value indicating the validation result. +/// +/// If `fields_buf` is null and `fields_len` or `fields_total` is non-null, it will write +/// the required buffer length and the total number of fields to the provided pointers. +/// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, +/// it will write the used fields to the buffer separated by '\0' and the total number of fields +/// to the `fields_total`, and `fields_len` will be updated with the total buffer length. +/// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, +/// it will write the required buffer length to the `fields_len`, and the total number of fields +/// to the `fields_total`, then error message will be written to the `errbuf` with the length +/// updated to `errbuf_len` and return `false`. +/// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. +/// /// /// # Panics /// /// This function will panic when: /// /// - `atc` doesn't point to a valid C-style string. +/// - `fields_len` and `fields_total` are null when `fields_buf` is non-null. /// /// # Safety /// @@ -608,15 +622,20 @@ pub unsafe extern "C" fn context_get_result( /// and must not have '\0' in the middle. /// - `schema` must be a valid pointer returned by [`schema_new`]. /// - `fields_buf` must be a valid to write for `fields_len * size_of::()` bytes, -/// and it must be properly aligned. +/// and it must be properly aligned if non-null. /// - `fields_len` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned. +/// and it must be properly aligned if non-null. /// - `fields_total` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned. +/// and it must be properly aligned if non-null. +/// - `operators` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned if non-null. /// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, /// and it must be properly aligned. /// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, /// and it must be properly aligned. +/// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. +/// - If `fields_buf` is null, `fields_len` and `fields_total` can be non-null +/// for writing required buffer length and total number of fields. #[cfg(feature = "expr_validation")] #[no_mangle] pub unsafe extern "C" fn expression_validate( @@ -659,66 +678,85 @@ pub unsafe extern "C" fn expression_validate( } // Get used fields - let mut expr_fields = HashMap::new(); - ast.add_to_counter(&mut expr_fields); - let fields_count = expr_fields.len(); - let mut fields = Vec::with_capacity(fields_count); - let mut total_fields_length = 0; - - for k in expr_fields.into_keys() { - total_fields_length += k.as_bytes().len() + 1; // +1 for trailing \0 - fields.push(k); - } + if !(fields_buf.is_null() && fields_len.is_null() && fields_total.is_null()) { + if !fields_buf.is_null() { + assert!( + !fields_len.is_null() && !fields_total.is_null(), + "fields_len and fields_total must be non-null when fields_buf is non-null" + ); + } - if *fields_len < total_fields_length { - let err_msg = format!( - "Fields buffer too small, provided {} bytes but required at least {} bytes.", - *fields_len, total_fields_length - ); - let errlen = min(err_msg.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&err_msg.as_bytes()[..errlen]); - *errbuf_len = errlen; - return false; + let mut expr_fields: HashMap = HashMap::new(); + ast.add_to_counter(&mut expr_fields); + let fields_count = expr_fields.len(); + let mut fields = Vec::with_capacity(fields_count); + let mut total_fields_length = 0; + + for k in expr_fields.into_keys() { + total_fields_length += k.as_bytes().len() + 1; // +1 for trailing \0 + fields.push(k); + } + + if !fields_buf.is_null() { + if *fields_len < total_fields_length { + let err_msg = format!( + "Fields buffer too small, provided {} bytes but required at least {} bytes.", + *fields_len, total_fields_length + ); + let errlen = min(err_msg.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&err_msg.as_bytes()[..errlen]); + *errbuf_len = errlen; + *fields_len = total_fields_length; + *fields_total = fields_count; + return false; + } + + let mut fields_buf_ptr = fields_buf; + for field in fields { + let field = ffi::CString::new(field).unwrap(); + let field_slice = field.as_bytes_with_nul(); + let field_len = field_slice.len(); + let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); + fields_buf.copy_from_slice(field_slice); + fields_buf_ptr = fields_buf_ptr.add(field_len); + } + } + + if !fields_len.is_null() { + *fields_len = total_fields_length; + } + if !fields_total.is_null() { + *fields_total = fields_count; + } } // Get used operators - let mut ops = BinaryOperatorFlags::empty(); - fn visit(expr: &Expression, ops: &mut BinaryOperatorFlags) { - match expr { - Expression::Logical(logic_expression) => match logic_expression.as_ref() { - LogicalExpression::And(lhs, rhs) => { - visit(lhs, ops); - visit(rhs, ops); + if !operators.is_null() { + let mut ops = BinaryOperatorFlags::empty(); + fn visit(expr: &Expression, ops: &mut BinaryOperatorFlags) { + match expr { + Expression::Logical(logic_expression) => match logic_expression.as_ref() { + LogicalExpression::And(lhs, rhs) => { + visit(lhs, ops); + visit(rhs, ops); + } + LogicalExpression::Or(lhs, rhs) => { + visit(lhs, ops); + visit(rhs, ops); + } + LogicalExpression::Not(rhs) => { + visit(rhs, ops); + } + }, + Expression::Predicate(predict) => { + let op = BinaryOperatorFlags::from(&predict.op); + ops.insert(op); } - LogicalExpression::Or(lhs, rhs) => { - visit(lhs, ops); - visit(rhs, ops); - } - LogicalExpression::Not(rhs) => { - visit(rhs, ops); - } - }, - Expression::Predicate(predict) => { - let op = BinaryOperatorFlags::from(&predict.op); - ops.insert(op); } } + visit(&ast, &mut ops); + *operators = ops.bits(); } - visit(&ast, &mut ops); - - // fulfill the output parameters, - let mut fields_buf_ptr = fields_buf; - for field in fields { - let field = ffi::CString::new(field).unwrap(); - let field_slice = field.as_bytes_with_nul(); - let field_len = field_slice.len(); - let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); - fields_buf.copy_from_slice(field_slice); - fields_buf_ptr = fields_buf_ptr.add(field_len); - } - *fields_total = fields_count; - *fields_len = total_fields_length; - *operators = ops.bits(); true } @@ -807,6 +845,7 @@ mod tests { assert!(result, "Validation failed"); assert_eq!(fields_total, 4, "Fields count mismatch"); + assert_eq!(fields_len, 47, "Fields buffer length mismatch"); assert_eq!( operators, (BinaryOperatorFlags::EQUALS From 73479945c002b3ce2fef6b7e25d323a482248b97 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 25 Sep 2024 14:28:56 +0800 Subject: [PATCH 04/23] feat(expressions-compatibility) use int as returned value --- cbindgen.toml | 2 +- src/ast.rs | 13 +++++++++++++ src/ffi.rs | 28 +++++++++++++++++++--------- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/cbindgen.toml b/cbindgen.toml index bfc55038..d316207e 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -12,4 +12,4 @@ prefix_with_name = true bitflags = true [export] -include = ["BinaryOperatorFlags"] \ No newline at end of file +include = ["BinaryOperatorFlags", "EXPRESSION_VALIDATE_OK", "EXPRESSION_VALIDATE_FAILED","EXPRESSION_VALIDATE_BUF_TOO_SMALL"] \ No newline at end of file diff --git a/src/ast.rs b/src/ast.rs index 37257e13..262102c6 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -62,6 +62,19 @@ bitflags::bitflags! { const IN = 1 << 9; const NOT_IN = 1 << 10; const CONTAINS = 1 << 11; + + const UNUSED = !(Self::EQUALS.bits() + | Self::NOT_EQUALS.bits() + | Self::REGEX.bits() + | Self::PREFIX.bits() + | Self::POSTFIX.bits() + | Self::GREATER.bits() + | Self::GREATER_OR_EQUAL.bits() + | Self::LESS.bits() + | Self::LESS_OR_EQUAL.bits() + | Self::IN.bits() + | Self::NOT_IN.bits() + | Self::CONTAINS.bits()); } } diff --git a/src/ffi.rs b/src/ffi.rs index f5757efc..40cc8f87 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -578,6 +578,13 @@ pub unsafe extern "C" fn context_get_result( .unwrap() } +#[cfg(feature = "expr_validation")] +pub const EXPRESSION_VALIDATE_OK: i64 = 0; +#[cfg(feature = "expr_validation")] +pub const EXPRESSION_VALIDATE_FAILED: i64 = 1; +#[cfg(feature = "expr_validation")] +pub const EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; + /// Validate the ATC expression with the schema. /// /// # Arguments @@ -593,7 +600,10 @@ pub unsafe extern "C" fn context_get_result( /// /// # Returns /// -/// Returns the boolean value indicating the validation result. +/// Returns an integer value indicating the validation result: +/// - EXPRESSION_VALIDATE_OK(0) if validation is passed. +/// - EXPRESSION_VALIDATE_FAILED(1) if validation is failed. +/// - EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. /// /// If `fields_buf` is null and `fields_len` or `fields_total` is non-null, it will write /// the required buffer length and the total number of fields to the provided pointers. @@ -602,9 +612,9 @@ pub unsafe extern "C" fn context_get_result( /// to the `fields_total`, and `fields_len` will be updated with the total buffer length. /// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, /// it will write the required buffer length to the `fields_len`, and the total number of fields -/// to the `fields_total`, then error message will be written to the `errbuf` with the length -/// updated to `errbuf_len` and return `false`. +/// to the `fields_total`, and return `2`. /// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. +/// The bitflags is defined by `BinaryOperatorFlags` and it must not contain any bits from `BinaryOperatorFlags::UNUSED`. /// /// /// # Panics @@ -647,7 +657,7 @@ pub unsafe extern "C" fn expression_validate( operators: *mut u64, errbuf: *mut u8, errbuf_len: *mut usize, -) -> bool { +) -> i64 { use std::collections::HashMap; use crate::ast::{BinaryOperatorFlags, Expression, LogicalExpression}; @@ -664,7 +674,7 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - return false; + return EXPRESSION_VALIDATE_FAILED; } // Unwrap is safe since we've already checked for error let ast = result.unwrap(); @@ -674,7 +684,7 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - return false; + return EXPRESSION_VALIDATE_FAILED; } // Get used fields @@ -708,7 +718,7 @@ pub unsafe extern "C" fn expression_validate( *errbuf_len = errlen; *fields_len = total_fields_length; *fields_total = fields_count; - return false; + return EXPRESSION_VALIDATE_BUF_TOO_SMALL; } let mut fields_buf_ptr = fields_buf; @@ -758,7 +768,7 @@ pub unsafe extern "C" fn expression_validate( *operators = ops.bits(); } - true + EXPRESSION_VALIDATE_OK } #[cfg(test)] @@ -843,7 +853,7 @@ mod tests { &mut errbuf_len, ); - assert!(result, "Validation failed"); + assert_eq!(result, EXPRESSION_VALIDATE_OK, "Validation failed"); assert_eq!(fields_total, 4, "Fields count mismatch"); assert_eq!(fields_len, 47, "Fields buffer length mismatch"); assert_eq!( From e70ffe4bbf5e97d21b05f8c2945c9d1336a4b99f Mon Sep 17 00:00:00 2001 From: Shiroko Date: Thu, 26 Sep 2024 14:28:53 +0800 Subject: [PATCH 05/23] feat(expressions-compatibility) correct constants namespace --- cbindgen.toml | 7 ++++++- src/ffi.rs | 25 ++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/cbindgen.toml b/cbindgen.toml index d316207e..4604ce2b 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -12,4 +12,9 @@ prefix_with_name = true bitflags = true [export] -include = ["BinaryOperatorFlags", "EXPRESSION_VALIDATE_OK", "EXPRESSION_VALIDATE_FAILED","EXPRESSION_VALIDATE_BUF_TOO_SMALL"] \ No newline at end of file +include = [ + "BinaryOperatorFlags", + "ATC_ROUTER_EXPRESSION_VALIDATE_OK", + "ATC_ROUTER_EXPRESSION_VALIDATE_FAILED", + "ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL" +] \ No newline at end of file diff --git a/src/ffi.rs b/src/ffi.rs index 40cc8f87..ecc9c888 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -579,11 +579,11 @@ pub unsafe extern "C" fn context_get_result( } #[cfg(feature = "expr_validation")] -pub const EXPRESSION_VALIDATE_OK: i64 = 0; +pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; #[cfg(feature = "expr_validation")] -pub const EXPRESSION_VALIDATE_FAILED: i64 = 1; +pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; #[cfg(feature = "expr_validation")] -pub const EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; +pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// Validate the ATC expression with the schema. /// @@ -601,9 +601,9 @@ pub const EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// # Returns /// /// Returns an integer value indicating the validation result: -/// - EXPRESSION_VALIDATE_OK(0) if validation is passed. -/// - EXPRESSION_VALIDATE_FAILED(1) if validation is failed. -/// - EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_OK(0) if validation is passed. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. /// /// If `fields_buf` is null and `fields_len` or `fields_total` is non-null, it will write /// the required buffer length and the total number of fields to the provided pointers. @@ -674,7 +674,7 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - return EXPRESSION_VALIDATE_FAILED; + return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } // Unwrap is safe since we've already checked for error let ast = result.unwrap(); @@ -684,7 +684,7 @@ pub unsafe extern "C" fn expression_validate( let errlen = min(e.len(), *errbuf_len); errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); *errbuf_len = errlen; - return EXPRESSION_VALIDATE_FAILED; + return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } // Get used fields @@ -718,7 +718,7 @@ pub unsafe extern "C" fn expression_validate( *errbuf_len = errlen; *fields_len = total_fields_length; *fields_total = fields_count; - return EXPRESSION_VALIDATE_BUF_TOO_SMALL; + return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; } let mut fields_buf_ptr = fields_buf; @@ -768,7 +768,7 @@ pub unsafe extern "C" fn expression_validate( *operators = ops.bits(); } - EXPRESSION_VALIDATE_OK + ATC_ROUTER_EXPRESSION_VALIDATE_OK } #[cfg(test)] @@ -853,7 +853,10 @@ mod tests { &mut errbuf_len, ); - assert_eq!(result, EXPRESSION_VALIDATE_OK, "Validation failed"); + assert_eq!( + result, ATC_ROUTER_EXPRESSION_VALIDATE_OK, + "Validation failed" + ); assert_eq!(fields_total, 4, "Fields count mismatch"); assert_eq!(fields_len, 47, "Fields buffer length mismatch"); assert_eq!( From 8b01b596913eeb2703f4680cb6777c61ffa9614e Mon Sep 17 00:00:00 2001 From: Shiroko Date: Sun, 29 Sep 2024 10:48:15 +0800 Subject: [PATCH 06/23] feat(expressions-compatibility): disable this feature by default --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5aa16258..62b3db88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ regex = "1" serde = { version = "1.0", features = ["derive"], optional = true } serde_regex = { version = "1.1", optional = true } fnv = "1" -bitflags = "2.6.0" +bitflags = { version = "2.6.0", optional = true } [dev-dependencies] criterion = "0.*" @@ -30,9 +30,10 @@ criterion = "0.*" crate-type = ["lib", "cdylib", "staticlib"] [features] -default = ["ffi", "expr_validation"] +default = ["ffi"] ffi = [] serde = ["cidr/serde", "dep:serde", "dep:serde_regex"] +expr_validation = ["dep:bitflags"] [[bench]] name = "test" @@ -53,4 +54,3 @@ harness = false [[bench]] name = "build" harness = false -expr_validation = [] From 466ad92d1e17141730bddf5227b7343211c07e45 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Sun, 29 Sep 2024 18:37:04 +0800 Subject: [PATCH 07/23] feat(expressions-compatibility): add more test --- src/ffi.rs | 190 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 141 insertions(+), 49 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index ecc9c888..dcd4f446 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -822,27 +822,22 @@ mod tests { } #[cfg(feature = "expr_validation")] - #[test] - fn test_expression_validate() { - use crate::ast::BinaryOperatorFlags; - unsafe { - let mut schema = Schema::default(); - let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; - let atc = ffi::CString::new(atc).unwrap(); - let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; - let mut errbuf_len = ERR_BUF_MAX_LEN; - - schema.add_field("net.protocol", Type::String); - schema.add_field("net.dst.port", Type::Int); - schema.add_field("net.src.ip", Type::IpAddr); - schema.add_field("http.path", Type::String); - - let mut fields_buf = vec![0u8; 1024]; - let mut fields_len = fields_buf.len(); - let mut fields_total = 0; - let mut operators = 0u64; - - let result = expression_validate( + fn expr_validate_on( + schema: &Schema, + atc: &str, + fields_buf_size: usize, + ) -> Result<(Vec, u64), (i64, String)> { + let atc = ffi::CString::new(atc).unwrap(); + let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; + let mut errbuf_len = ERR_BUF_MAX_LEN; + + let mut fields_buf = vec![0u8; fields_buf_size]; + let mut fields_len = fields_buf.len(); + let mut fields_total = 0; + let mut operators = 0u64; + + let result = unsafe { + expression_validate( atc.as_bytes().as_ptr(), &schema, fields_buf.as_mut_ptr(), @@ -851,43 +846,140 @@ mod tests { &mut operators, errbuf.as_mut_ptr(), &mut errbuf_len, - ); + ) + }; - assert_eq!( - result, ATC_ROUTER_EXPRESSION_VALIDATE_OK, - "Validation failed" - ); - assert_eq!(fields_total, 4, "Fields count mismatch"); - assert_eq!(fields_len, 47, "Fields buffer length mismatch"); - assert_eq!( - operators, - (BinaryOperatorFlags::EQUALS - | BinaryOperatorFlags::REGEX - | BinaryOperatorFlags::IN - | BinaryOperatorFlags::NOT_IN - | BinaryOperatorFlags::CONTAINS) - .bits(), - "Operators mismatch" - ); + if result == ATC_ROUTER_EXPRESSION_VALIDATE_OK { let mut fields = Vec::::with_capacity(fields_total); let mut p = 0; for _ in 0..fields_total { - let field = ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()); + let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; let len = field.to_bytes().len() + 1; fields.push(field.to_string_lossy().to_string()); p += len; } + assert_eq!(fields_len, p, "Fields buffer length mismatch"); fields.sort(); - assert_eq!( - fields, - vec![ - "http.path".to_string(), - "net.dst.port".to_string(), - "net.protocol".to_string(), - "net.src.ip".to_string() - ], - "Fields mismatch" - ); + Ok((fields, operators)) + } else { + let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); + Err((result, err)) } } + + #[cfg(feature = "expr_validation")] + #[test] + fn test_expression_validate_success() { + use crate::ast::BinaryOperatorFlags; + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_ok(), "Validation failed"); + let (fields, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it + assert_eq!( + ops, + (BinaryOperatorFlags::EQUALS + | BinaryOperatorFlags::REGEX + | BinaryOperatorFlags::IN + | BinaryOperatorFlags::NOT_IN + | BinaryOperatorFlags::CONTAINS) + .bits(), + "Operators mismatch" + ); + assert_eq!( + fields, + vec![ + "http.path".to_string(), + "net.dst.port".to_string(), + "net.protocol".to_string(), + "net.src.ip".to_string() + ], + "Fields mismatch" + ); + } + + #[cfg(feature = "expr_validation")] + #[test] + fn test_expression_validate_failed_parse() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_err(), "Validation unexcepted success"); + let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, + "Error code mismatch" + ); + assert_eq!( + err_message, + "In/NotIn operators only supports IP in CIDR".to_string(), + "Error message mismatch" + ); + } + + #[cfg(feature = "expr_validation")] + #[test] + fn test_expression_validate_failed_validate() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_err(), "Validation unexcepted success"); + let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, + "Error code mismatch" + ); + assert_eq!( + err_message, + "Unknown LHS field".to_string(), + "Error message mismatch" + ); + } + + #[cfg(feature = "expr_validation")] + #[test] + fn test_expression_validate_buf_too_small() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 10); + + assert!(result.is_err(), "Validation failed"); + let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL, + "Error code mismatch" + ); + assert_eq!( + err_message, + "Fields buffer too small, provided 10 bytes but required at least 47 bytes." + .to_string(), + "Error message mismatch" + ); + } } From 1cbe9b73065c88a3f878a27e9ed6e488561e31fd Mon Sep 17 00:00:00 2001 From: Shiroko Date: Mon, 30 Sep 2024 12:09:12 +0800 Subject: [PATCH 08/23] feat(expressions-compatibility): use a tarit to get metadata from expression --- src/ffi.rs | 42 ++++++++--------------- src/semantics.rs | 87 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 99 insertions(+), 30 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index dcd4f446..86823e06 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -658,12 +658,11 @@ pub unsafe extern "C" fn expression_validate( errbuf: *mut u8, errbuf_len: *mut usize, ) -> i64 { - use std::collections::HashMap; + use std::collections::HashSet; - use crate::ast::{BinaryOperatorFlags, Expression, LogicalExpression}; + use crate::ast::BinaryOperatorFlags; use crate::parser::parse; - use crate::semantics::FieldCounter; - use crate::semantics::Validate; + use crate::semantics::{GetPredicates, Validate}; let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); @@ -687,6 +686,9 @@ pub unsafe extern "C" fn expression_validate( return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } + // Direct use GetPredicates trait to avoid unnecessary access + let predicates = ast.get_predicates(); + // Get used fields if !(fields_buf.is_null() && fields_len.is_null() && fields_total.is_null()) { if !fields_buf.is_null() { @@ -696,13 +698,16 @@ pub unsafe extern "C" fn expression_validate( ); } - let mut expr_fields: HashMap = HashMap::new(); - ast.add_to_counter(&mut expr_fields); + let mut expr_fields: HashSet = HashSet::new(); + for pred in &predicates { + expr_fields.insert(pred.lhs.var_name.clone()); + } + let fields_count = expr_fields.len(); let mut fields = Vec::with_capacity(fields_count); let mut total_fields_length = 0; - for k in expr_fields.into_keys() { + for k in expr_fields { total_fields_length += k.as_bytes().len() + 1; // +1 for trailing \0 fields.push(k); } @@ -743,28 +748,9 @@ pub unsafe extern "C" fn expression_validate( // Get used operators if !operators.is_null() { let mut ops = BinaryOperatorFlags::empty(); - fn visit(expr: &Expression, ops: &mut BinaryOperatorFlags) { - match expr { - Expression::Logical(logic_expression) => match logic_expression.as_ref() { - LogicalExpression::And(lhs, rhs) => { - visit(lhs, ops); - visit(rhs, ops); - } - LogicalExpression::Or(lhs, rhs) => { - visit(lhs, ops); - visit(rhs, ops); - } - LogicalExpression::Not(rhs) => { - visit(rhs, ops); - } - }, - Expression::Predicate(predict) => { - let op = BinaryOperatorFlags::from(&predict.op); - ops.insert(op); - } - } + for pred in &predicates { + ops |= BinaryOperatorFlags::from(&pred.op); } - visit(&ast, &mut ops); *operators = ops.bits(); } diff --git a/src/semantics.rs b/src/semantics.rs index 9c4dc5b8..6ee58992 100644 --- a/src/semantics.rs +++ b/src/semantics.rs @@ -1,6 +1,6 @@ -use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value}; +use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate, Type, Value}; use crate::schema::Schema; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; type ValidationResult = Result<(), String>; @@ -13,6 +13,21 @@ pub trait FieldCounter { fn remove_from_counter(&self, map: &mut HashMap); } +#[cfg(feature = "expr_validation")] +pub trait GetFields { + fn get_fields(&self) -> HashSet; +} + +#[cfg(feature = "expr_validation")] +pub trait GetOperators { + fn get_operators(&self) -> crate::ast::BinaryOperatorFlags; +} + +#[cfg(feature = "expr_validation")] +pub trait GetPredicates { + fn get_predicates(&self) -> Vec<&Predicate>; +} + impl Validate for Expression { fn validate(&self, schema: &Schema) -> ValidationResult { match self { @@ -110,6 +125,74 @@ impl Validate for Expression { } } +#[cfg(feature = "expr_validation")] +impl GetPredicates for Expression { + fn get_predicates(&self) -> Vec<&Predicate> { + let mut predicates = Vec::new(); + + fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) + where + 'a: 'b, + { + match expr { + Expression::Logical(l) => match l.as_ref() { + LogicalExpression::And(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Or(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Not(r) => { + visit(r, predicates); + } + }, + Expression::Predicate(p) => { + predicates.push(p); + } + } + } + + visit(self, &mut predicates); + + predicates + } +} + +#[cfg(feature = "expr_validation")] +impl GetFields for T +where + T: GetPredicates, +{ + fn get_fields(&self) -> HashSet { + let mut fields = HashSet::new(); + + for predicate in self.get_predicates() { + fields.insert(predicate.lhs.var_name.clone()); + } + + fields + } +} + +#[cfg(feature = "expr_validation")] +impl GetOperators for T +where + T: GetPredicates, +{ + fn get_operators(&self) -> crate::ast::BinaryOperatorFlags { + use crate::ast::BinaryOperatorFlags; + let mut ops = BinaryOperatorFlags::empty(); + + for predicate in self.get_predicates() { + ops |= BinaryOperatorFlags::from(&predicate.op); + } + + ops + } +} + #[cfg(test)] mod tests { use super::*; From 2edf23aeccec557463e5c8d1e9f12cf0861463f7 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Mon, 30 Sep 2024 15:07:36 +0800 Subject: [PATCH 09/23] feat(expressions-compatibility): remove unused trait and avoid unneceaasry clone --- src/ast.rs | 11 +++++++++++ src/ffi.rs | 33 ++++++++++++++------------------ src/semantics.rs | 50 +++++------------------------------------------- 3 files changed, 30 insertions(+), 64 deletions(-) diff --git a/src/ast.rs b/src/ast.rs index 262102c6..cc76b9e5 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -186,6 +186,17 @@ pub struct Predicate { pub op: BinaryOperator, } +#[cfg(feature = "expr_validation")] +impl Predicate { + pub fn get_field(&self) -> &str { + &self.lhs.var_name + } + + pub fn get_operator(&self) -> &BinaryOperator { + &self.op + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/ffi.rs b/src/ffi.rs index 86823e06..72c064a2 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -686,7 +686,7 @@ pub unsafe extern "C" fn expression_validate( return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } - // Direct use GetPredicates trait to avoid unnecessary access + // Direct use GetPredicates trait to avoid unnecessary accesses let predicates = ast.get_predicates(); // Get used fields @@ -698,19 +698,14 @@ pub unsafe extern "C" fn expression_validate( ); } - let mut expr_fields: HashSet = HashSet::new(); - for pred in &predicates { - expr_fields.insert(pred.lhs.var_name.clone()); - } - - let fields_count = expr_fields.len(); - let mut fields = Vec::with_capacity(fields_count); - let mut total_fields_length = 0; - - for k in expr_fields { - total_fields_length += k.as_bytes().len() + 1; // +1 for trailing \0 - fields.push(k); - } + let expr_fields = predicates + .iter() + .map(|p| p.get_field()) + .collect::>(); + let total_fields_length = expr_fields + .iter() + .map(|k| k.as_bytes().len() + 1) + .sum::(); if !fields_buf.is_null() { if *fields_len < total_fields_length { @@ -722,13 +717,13 @@ pub unsafe extern "C" fn expression_validate( errbuf[..errlen].copy_from_slice(&err_msg.as_bytes()[..errlen]); *errbuf_len = errlen; *fields_len = total_fields_length; - *fields_total = fields_count; + *fields_total = expr_fields.len(); return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; } let mut fields_buf_ptr = fields_buf; - for field in fields { - let field = ffi::CString::new(field).unwrap(); + for field in &expr_fields { + let field = ffi::CString::new(*field).unwrap(); let field_slice = field.as_bytes_with_nul(); let field_len = field_slice.len(); let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); @@ -741,7 +736,7 @@ pub unsafe extern "C" fn expression_validate( *fields_len = total_fields_length; } if !fields_total.is_null() { - *fields_total = fields_count; + *fields_total = expr_fields.len(); } } @@ -749,7 +744,7 @@ pub unsafe extern "C" fn expression_validate( if !operators.is_null() { let mut ops = BinaryOperatorFlags::empty(); for pred in &predicates { - ops |= BinaryOperatorFlags::from(&pred.op); + ops |= BinaryOperatorFlags::from(pred.get_operator()); } *operators = ops.bits(); } diff --git a/src/semantics.rs b/src/semantics.rs index 6ee58992..a356fc3a 100644 --- a/src/semantics.rs +++ b/src/semantics.rs @@ -1,6 +1,9 @@ -use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate, Type, Value}; +use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value}; use crate::schema::Schema; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; + +#[cfg(feature = "expr_validation")] +use crate::ast::Predicate; type ValidationResult = Result<(), String>; @@ -13,16 +16,6 @@ pub trait FieldCounter { fn remove_from_counter(&self, map: &mut HashMap); } -#[cfg(feature = "expr_validation")] -pub trait GetFields { - fn get_fields(&self) -> HashSet; -} - -#[cfg(feature = "expr_validation")] -pub trait GetOperators { - fn get_operators(&self) -> crate::ast::BinaryOperatorFlags; -} - #[cfg(feature = "expr_validation")] pub trait GetPredicates { fn get_predicates(&self) -> Vec<&Predicate>; @@ -160,39 +153,6 @@ impl GetPredicates for Expression { } } -#[cfg(feature = "expr_validation")] -impl GetFields for T -where - T: GetPredicates, -{ - fn get_fields(&self) -> HashSet { - let mut fields = HashSet::new(); - - for predicate in self.get_predicates() { - fields.insert(predicate.lhs.var_name.clone()); - } - - fields - } -} - -#[cfg(feature = "expr_validation")] -impl GetOperators for T -where - T: GetPredicates, -{ - fn get_operators(&self) -> crate::ast::BinaryOperatorFlags { - use crate::ast::BinaryOperatorFlags; - let mut ops = BinaryOperatorFlags::empty(); - - for predicate in self.get_predicates() { - ops |= BinaryOperatorFlags::from(&predicate.op); - } - - ops - } -} - #[cfg(test)] mod tests { use super::*; From b69f26fb6448f90e7625ea37d41b5d23121c3c3d Mon Sep 17 00:00:00 2001 From: Haoxuan Date: Wed, 9 Oct 2024 10:37:28 +0800 Subject: [PATCH 10/23] feat(expressions-compatibility): update interface doc for clarity. Co-authored-by: Javier --- src/ffi.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ffi.rs b/src/ffi.rs index 72c064a2..16af1e0a 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -608,7 +608,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// If `fields_buf` is null and `fields_len` or `fields_total` is non-null, it will write /// the required buffer length and the total number of fields to the provided pointers. /// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, -/// it will write the used fields to the buffer separated by '\0' and the total number of fields +/// it will write the used fields to the buffer, each terminated by '\0' and the total number of fields /// to the `fields_total`, and `fields_len` will be updated with the total buffer length. /// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, /// it will write the required buffer length to the `fields_len`, and the total number of fields From c26656cee3e498769b3de8292f1b4edae5705c2f Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 30 Oct 2024 16:28:48 +0800 Subject: [PATCH 11/23] feat(expression-compatibility): remove feature flags and use `ffi` instead. This commit contains two parts: 1. Moving everything guarded by feature flag `expr_validation` to `ffi.rs`. 2. Remove feature flag `expr_validation`. 3. Remove public interfaces `get_field` and `get_operator` for `Predicate`. No internal logic has been changed. --- Cargo.toml | 3 +- cbindgen.toml | 1 - src/ast.rs | 64 ----------------------------- src/ffi.rs | 105 ++++++++++++++++++++++++++++++++++++++++------- src/semantics.rs | 43 ------------------- 5 files changed, 91 insertions(+), 125 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62b3db88..6807824e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,9 +31,8 @@ crate-type = ["lib", "cdylib", "staticlib"] [features] default = ["ffi"] -ffi = [] +ffi = ["dep:bitflags"] serde = ["cidr/serde", "dep:serde", "dep:serde_regex"] -expr_validation = ["dep:bitflags"] [[bench]] name = "test" diff --git a/cbindgen.toml b/cbindgen.toml index 4604ce2b..87f8e136 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -6,7 +6,6 @@ prefix_with_name = true [defines] "feature = ffi" = "DEFINE_ATC_ROUTER_FFI" -"feature = expr_validation" = "DEFINE_ATC_ROUTER_EXPR_VALIDATION" [macro_expansion] bitflags = true diff --git a/src/ast.rs b/src/ast.rs index cc76b9e5..42a04752 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -45,59 +45,6 @@ pub enum BinaryOperator { Contains, // contains } -#[cfg(feature = "expr_validation")] -bitflags::bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] - #[repr(C)] - pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { - const EQUALS = 1 << 0; - const NOT_EQUALS = 1 << 1; - const REGEX = 1 << 2; - const PREFIX = 1 << 3; - const POSTFIX = 1 << 4; - const GREATER = 1 << 5; - const GREATER_OR_EQUAL = 1 << 6; - const LESS = 1 << 7; - const LESS_OR_EQUAL = 1 << 8; - const IN = 1 << 9; - const NOT_IN = 1 << 10; - const CONTAINS = 1 << 11; - - const UNUSED = !(Self::EQUALS.bits() - | Self::NOT_EQUALS.bits() - | Self::REGEX.bits() - | Self::PREFIX.bits() - | Self::POSTFIX.bits() - | Self::GREATER.bits() - | Self::GREATER_OR_EQUAL.bits() - | Self::LESS.bits() - | Self::LESS_OR_EQUAL.bits() - | Self::IN.bits() - | Self::NOT_IN.bits() - | Self::CONTAINS.bits()); - } -} - -#[cfg(feature = "expr_validation")] -impl From<&BinaryOperator> for BinaryOperatorFlags { - fn from(op: &BinaryOperator) -> Self { - match op { - BinaryOperator::Equals => Self::EQUALS, - BinaryOperator::NotEquals => Self::NOT_EQUALS, - BinaryOperator::Regex => Self::REGEX, - BinaryOperator::Prefix => Self::PREFIX, - BinaryOperator::Postfix => Self::POSTFIX, - BinaryOperator::Greater => Self::GREATER, - BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, - BinaryOperator::Less => Self::LESS, - BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, - BinaryOperator::In => Self::IN, - BinaryOperator::NotIn => Self::NOT_IN, - BinaryOperator::Contains => Self::CONTAINS, - } - } -} - #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub enum Value { @@ -186,17 +133,6 @@ pub struct Predicate { pub op: BinaryOperator, } -#[cfg(feature = "expr_validation")] -impl Predicate { - pub fn get_field(&self) -> &str { - &self.lhs.var_name - } - - pub fn get_operator(&self) -> &BinaryOperator { - &self.op - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/ffi.rs b/src/ffi.rs index 16af1e0a..43d11b46 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -1,7 +1,8 @@ -use crate::ast::{Type, Value}; +use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate, Type, Value}; use crate::context::Context; use crate::router::Router; use crate::schema::Schema; +use bitflags::bitflags; use cidr::IpCidr; use std::cmp::min; use std::convert::TryFrom; @@ -578,11 +579,93 @@ pub unsafe extern "C" fn context_get_result( .unwrap() } -#[cfg(feature = "expr_validation")] +impl Expression { + fn get_predicates(&self) -> Vec<&Predicate> { + let mut predicates = Vec::new(); + + fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) + where + 'a: 'b, + { + match expr { + Expression::Logical(l) => match l.as_ref() { + LogicalExpression::And(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Or(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Not(r) => { + visit(r, predicates); + } + }, + Expression::Predicate(p) => { + predicates.push(p); + } + } + } + + visit(self, &mut predicates); + + predicates + } +} + +bitflags! { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[repr(C)] + pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { + const EQUALS = 1 << 0; + const NOT_EQUALS = 1 << 1; + const REGEX = 1 << 2; + const PREFIX = 1 << 3; + const POSTFIX = 1 << 4; + const GREATER = 1 << 5; + const GREATER_OR_EQUAL = 1 << 6; + const LESS = 1 << 7; + const LESS_OR_EQUAL = 1 << 8; + const IN = 1 << 9; + const NOT_IN = 1 << 10; + const CONTAINS = 1 << 11; + + const UNUSED = !(Self::EQUALS.bits() + | Self::NOT_EQUALS.bits() + | Self::REGEX.bits() + | Self::PREFIX.bits() + | Self::POSTFIX.bits() + | Self::GREATER.bits() + | Self::GREATER_OR_EQUAL.bits() + | Self::LESS.bits() + | Self::LESS_OR_EQUAL.bits() + | Self::IN.bits() + | Self::NOT_IN.bits() + | Self::CONTAINS.bits()); + } +} + +impl From<&BinaryOperator> for BinaryOperatorFlags { + fn from(op: &BinaryOperator) -> Self { + match op { + BinaryOperator::Equals => Self::EQUALS, + BinaryOperator::NotEquals => Self::NOT_EQUALS, + BinaryOperator::Regex => Self::REGEX, + BinaryOperator::Prefix => Self::PREFIX, + BinaryOperator::Postfix => Self::POSTFIX, + BinaryOperator::Greater => Self::GREATER, + BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, + BinaryOperator::Less => Self::LESS, + BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, + BinaryOperator::In => Self::IN, + BinaryOperator::NotIn => Self::NOT_IN, + BinaryOperator::Contains => Self::CONTAINS, + } + } +} + pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; -#[cfg(feature = "expr_validation")] pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; -#[cfg(feature = "expr_validation")] pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// Validate the ATC expression with the schema. @@ -646,7 +729,6 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. /// - If `fields_buf` is null, `fields_len` and `fields_total` can be non-null /// for writing required buffer length and total number of fields. -#[cfg(feature = "expr_validation")] #[no_mangle] pub unsafe extern "C" fn expression_validate( atc: *const u8, @@ -660,9 +742,8 @@ pub unsafe extern "C" fn expression_validate( ) -> i64 { use std::collections::HashSet; - use crate::ast::BinaryOperatorFlags; use crate::parser::parse; - use crate::semantics::{GetPredicates, Validate}; + use crate::semantics::Validate; let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); @@ -700,7 +781,7 @@ pub unsafe extern "C" fn expression_validate( let expr_fields = predicates .iter() - .map(|p| p.get_field()) + .map(|p| p.lhs.var_name.as_str()) .collect::>(); let total_fields_length = expr_fields .iter() @@ -744,7 +825,7 @@ pub unsafe extern "C" fn expression_validate( if !operators.is_null() { let mut ops = BinaryOperatorFlags::empty(); for pred in &predicates { - ops |= BinaryOperatorFlags::from(pred.get_operator()); + ops |= BinaryOperatorFlags::from(&pred.op); } *operators = ops.bits(); } @@ -802,7 +883,6 @@ mod tests { } } - #[cfg(feature = "expr_validation")] fn expr_validate_on( schema: &Schema, atc: &str, @@ -848,10 +928,8 @@ mod tests { } } - #[cfg(feature = "expr_validation")] #[test] fn test_expression_validate_success() { - use crate::ast::BinaryOperatorFlags; let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; let mut schema = Schema::default(); @@ -886,7 +964,6 @@ mod tests { ); } - #[cfg(feature = "expr_validation")] #[test] fn test_expression_validate_failed_parse() { let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##; @@ -912,7 +989,6 @@ mod tests { ); } - #[cfg(feature = "expr_validation")] #[test] fn test_expression_validate_failed_validate() { let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; @@ -937,7 +1013,6 @@ mod tests { ); } - #[cfg(feature = "expr_validation")] #[test] fn test_expression_validate_buf_too_small() { let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; diff --git a/src/semantics.rs b/src/semantics.rs index a356fc3a..9c4dc5b8 100644 --- a/src/semantics.rs +++ b/src/semantics.rs @@ -2,9 +2,6 @@ use crate::ast::{BinaryOperator, Expression, LogicalExpression, Type, Value}; use crate::schema::Schema; use std::collections::HashMap; -#[cfg(feature = "expr_validation")] -use crate::ast::Predicate; - type ValidationResult = Result<(), String>; pub trait Validate { @@ -16,11 +13,6 @@ pub trait FieldCounter { fn remove_from_counter(&self, map: &mut HashMap); } -#[cfg(feature = "expr_validation")] -pub trait GetPredicates { - fn get_predicates(&self) -> Vec<&Predicate>; -} - impl Validate for Expression { fn validate(&self, schema: &Schema) -> ValidationResult { match self { @@ -118,41 +110,6 @@ impl Validate for Expression { } } -#[cfg(feature = "expr_validation")] -impl GetPredicates for Expression { - fn get_predicates(&self) -> Vec<&Predicate> { - let mut predicates = Vec::new(); - - fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) - where - 'a: 'b, - { - match expr { - Expression::Logical(l) => match l.as_ref() { - LogicalExpression::And(l, r) => { - visit(l, predicates); - visit(r, predicates); - } - LogicalExpression::Or(l, r) => { - visit(l, predicates); - visit(r, predicates); - } - LogicalExpression::Not(r) => { - visit(r, predicates); - } - }, - Expression::Predicate(p) => { - predicates.push(p); - } - } - } - - visit(self, &mut predicates); - - predicates - } -} - #[cfg(test)] mod tests { use super::*; From 0ffb54b90096798b8f4a8b9629afddda4bbd1ace Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 30 Oct 2024 16:54:06 +0800 Subject: [PATCH 12/23] feat(expression-compatibility): remove some unnecessary code. remove the redundant error message when buffer is too small. remove the ability to fetch fields count and total length without a valid `field_buf`. Must pass a valid `field_buf` if calling for getting fields. --- src/ffi.rs | 77 ++++++++++++++++++++++-------------------------------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/src/ffi.rs b/src/ffi.rs index 43d11b46..3aabf096 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -685,17 +685,17 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// /// Returns an integer value indicating the validation result: /// - ATC_ROUTER_EXPRESSION_VALIDATE_OK(0) if validation is passed. -/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. The `errbuf` and `errbuf_len` will be updated with the error message. /// - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. /// -/// If `fields_buf` is null and `fields_len` or `fields_total` is non-null, it will write -/// the required buffer length and the total number of fields to the provided pointers. /// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, /// it will write the used fields to the buffer, each terminated by '\0' and the total number of fields /// to the `fields_total`, and `fields_len` will be updated with the total buffer length. +/// /// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, /// it will write the required buffer length to the `fields_len`, and the total number of fields -/// to the `fields_total`, and return `2`. +/// to the `fields_total`, and return `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. +/// /// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. /// The bitflags is defined by `BinaryOperatorFlags` and it must not contain any bits from `BinaryOperatorFlags::UNUSED`. /// @@ -727,7 +727,6 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, /// and it must be properly aligned. /// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. -/// - If `fields_buf` is null, `fields_len` and `fields_total` can be non-null /// for writing required buffer length and total number of fields. #[no_mangle] pub unsafe extern "C" fn expression_validate( @@ -771,13 +770,11 @@ pub unsafe extern "C" fn expression_validate( let predicates = ast.get_predicates(); // Get used fields - if !(fields_buf.is_null() && fields_len.is_null() && fields_total.is_null()) { - if !fields_buf.is_null() { - assert!( - !fields_len.is_null() && !fields_total.is_null(), - "fields_len and fields_total must be non-null when fields_buf is non-null" - ); - } + if !fields_buf.is_null() { + assert!( + !(fields_len.is_null() || fields_total.is_null()), + "fields_len and fields_total must be non-null when fields_buf is non-null" + ); let expr_fields = predicates .iter() @@ -790,13 +787,6 @@ pub unsafe extern "C" fn expression_validate( if !fields_buf.is_null() { if *fields_len < total_fields_length { - let err_msg = format!( - "Fields buffer too small, provided {} bytes but required at least {} bytes.", - *fields_len, total_fields_length - ); - let errlen = min(err_msg.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&err_msg.as_bytes()[..errlen]); - *errbuf_len = errlen; *fields_len = total_fields_length; *fields_total = expr_fields.len(); return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; @@ -813,12 +803,8 @@ pub unsafe extern "C" fn expression_validate( } } - if !fields_len.is_null() { - *fields_len = total_fields_length; - } - if !fields_total.is_null() { - *fields_total = expr_fields.len(); - } + *fields_len = total_fields_length; + *fields_total = expr_fields.len(); } // Get used operators @@ -910,21 +896,26 @@ mod tests { ) }; - if result == ATC_ROUTER_EXPRESSION_VALIDATE_OK { - let mut fields = Vec::::with_capacity(fields_total); - let mut p = 0; - for _ in 0..fields_total { - let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; - let len = field.to_bytes().len() + 1; - fields.push(field.to_string_lossy().to_string()); - p += len; + match result { + ATC_ROUTER_EXPRESSION_VALIDATE_OK => { + let mut fields = Vec::::with_capacity(fields_total); + let mut p = 0; + for _ in 0..fields_total { + let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; + let len = field.to_bytes().len() + 1; + fields.push(field.to_string_lossy().to_string()); + p += len; + } + assert_eq!(fields_len, p, "Fields buffer length mismatch"); + fields.sort(); + Ok((fields, operators)) + } + ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { + let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); + Err((result, err)) } - assert_eq!(fields_len, p, "Fields buffer length mismatch"); - fields.sort(); - Ok((fields, operators)) - } else { - let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); - Err((result, err)) + ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), + _ => panic!("Unknown error code"), } } @@ -1026,16 +1017,10 @@ mod tests { let result = expr_validate_on(&schema, atc, 10); assert!(result.is_err(), "Validation failed"); - let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it assert_eq!( err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL, "Error code mismatch" ); - assert_eq!( - err_message, - "Fields buffer too small, provided 10 bytes but required at least 47 bytes." - .to_string(), - "Error message mismatch" - ); } } From 56d2aecc2c63540bfc27289fbad361ccadb17d2a Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 30 Oct 2024 17:07:03 +0800 Subject: [PATCH 13/23] feat(expression-compatibility): break `ffi.rs` into module. --- src/ffi.rs | 1026 ----------------------------------------- src/ffi/context.rs | 251 ++++++++++ src/ffi/expression.rs | 409 ++++++++++++++++ src/ffi/mod.rs | 58 +++ src/ffi/router.rs | 288 ++++++++++++ src/ffi/schema.rs | 54 +++ 6 files changed, 1060 insertions(+), 1026 deletions(-) delete mode 100644 src/ffi.rs create mode 100644 src/ffi/context.rs create mode 100644 src/ffi/expression.rs create mode 100644 src/ffi/mod.rs create mode 100644 src/ffi/router.rs create mode 100644 src/ffi/schema.rs diff --git a/src/ffi.rs b/src/ffi.rs deleted file mode 100644 index 3aabf096..00000000 --- a/src/ffi.rs +++ /dev/null @@ -1,1026 +0,0 @@ -use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate, Type, Value}; -use crate::context::Context; -use crate::router::Router; -use crate::schema::Schema; -use bitflags::bitflags; -use cidr::IpCidr; -use std::cmp::min; -use std::convert::TryFrom; -use std::ffi; -use std::net::IpAddr; -use std::os::raw::c_char; -use std::slice::{from_raw_parts, from_raw_parts_mut}; -use uuid::fmt::Hyphenated; -use uuid::Uuid; - -pub const ERR_BUF_MAX_LEN: usize = 4096; - -#[derive(Debug)] -#[repr(C)] -pub enum CValue { - Str(*const u8, usize), - IpCidr(*const u8), - IpAddr(*const u8), - Int(i64), -} - -impl TryFrom<&CValue> for Value { - type Error = String; - - fn try_from(v: &CValue) -> Result { - Ok(match v { - CValue::Str(s, len) => Self::String(unsafe { - std::str::from_utf8(from_raw_parts(*s, *len)) - .map_err(|e| e.to_string())? - .to_string() - }), - CValue::IpCidr(s) => Self::IpCidr( - unsafe { - ffi::CStr::from_ptr(*s as *const c_char) - .to_str() - .map_err(|e| e.to_string())? - .to_string() - } - .parse::() - .map_err(|e| e.to_string())?, - ), - CValue::IpAddr(s) => Self::IpAddr( - unsafe { - ffi::CStr::from_ptr(*s as *const c_char) - .to_str() - .map_err(|e| e.to_string())? - .to_string() - } - .parse::() - .map_err(|e| e.to_string())?, - ), - CValue::Int(i) => Self::Int(*i), - }) - } -} - -#[no_mangle] -pub extern "C" fn schema_new() -> *mut Schema { - Box::into_raw(Box::default()) -} - -/// Deallocate the schema object. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `schema` must be a valid pointer returned by [`schema_new`]. -#[no_mangle] -pub unsafe extern "C" fn schema_free(schema: *mut Schema) { - drop(Box::from_raw(schema)); -} - -/// Add a new field with the specified type to the schema. -/// -/// # Arguments -/// -/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. -/// - `field`: the C-style string representing the field name. -/// - `typ`: the type of the field. -/// -/// # Panics -/// -/// This function will panic if the C-style string -/// pointed by `field` is not a valid UTF-8 string. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `schema` must be a valid pointer returned by [`schema_new`]. -/// - `field` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. -#[no_mangle] -pub unsafe extern "C" fn schema_add_field(schema: &mut Schema, field: *const i8, typ: Type) { - let field = ffi::CStr::from_ptr(field as *const c_char) - .to_str() - .unwrap(); - - schema.add_field(field, typ) -} - -/// Create a new router object associated with the schema. -/// -/// # Arguments -/// -/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `schema` must be a valid pointer returned by [`schema_new`]. -#[no_mangle] -pub unsafe extern "C" fn router_new(schema: &Schema) -> *mut Router { - Box::into_raw(Box::new(Router::new(schema))) -} - -/// Deallocate the router object. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `router` must be a valid pointer returned by [`router_new`]. -#[no_mangle] -pub unsafe extern "C" fn router_free(router: *mut Router) { - drop(Box::from_raw(router)); -} - -/// Add a new matcher to the router. -/// -/// # Arguments -/// -/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. -/// - `priority`: the priority of the matcher, higher value means higher priority, -/// and the matcher with the highest priority will be executed first. -/// - `uuid`: the C-style string representing the UUID of the matcher. -/// - `atc`: the C-style string representing the ATC expression. -/// - `errbuf`: a buffer to store the error message. -/// - `errbuf_len`: a pointer to the length of the error message buffer. -/// -/// # Returns -/// -/// Returns `true` if the matcher was added successfully, otherwise `false`, -/// and the error message will be stored in the `errbuf`, -/// and the length of the error message will be stored in `errbuf_len`. -/// -/// # Errors -/// -/// This function will return `false` if the matcher could not be added to the router, -/// such as duplicate UUID, and invalid ATC expression. -/// -/// # Panics -/// -/// This function will panic when: -/// -/// - `uuid` doesn't point to a ASCII sequence representing a valid 128-bit UUID. -/// - `atc` doesn't point to a valid C-style string. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `router` must be a valid pointer returned by [`router_new`]. -/// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. -/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. -/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, -/// and it must be properly aligned. -/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, -/// and it must be properly aligned. -#[no_mangle] -pub unsafe extern "C" fn router_add_matcher( - router: &mut Router, - priority: usize, - uuid: *const i8, - atc: *const i8, - errbuf: *mut u8, - errbuf_len: *mut usize, -) -> bool { - let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); - let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); - - let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); - - if let Err(e) = router.add_matcher(priority, uuid, atc) { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; - return false; - } - - true -} - -/// Remove a matcher from the router. -/// -/// # Arguments -/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. -/// - `priority`: the priority of the matcher to be removed. -/// - `uuid`: the C-style string representing the UUID of the matcher to be removed. -/// -/// # Returns -/// -/// Returns `true` if the matcher was removed successfully, otherwise `false`, -/// such as when the matcher with the specified UUID doesn't exist or -/// the priority doesn't match the UUID. -/// -/// # Panics -/// -/// This function will panic when `uuid` doesn't point to a ASCII sequence -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `router` must be a valid pointer returned by [`router_new`]. -/// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. -#[no_mangle] -pub unsafe extern "C" fn router_remove_matcher( - router: &mut Router, - priority: usize, - uuid: *const i8, -) -> bool { - let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); - let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); - - router.remove_matcher(priority, uuid) -} - -/// Execute the router with the context. -/// -/// # Arguments -/// -/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. -/// - `context`: a pointer to the [`Context`] object. -/// -/// # Returns -/// -/// Returns `true` if found a match, `false` means no match found. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `router` must be a valid pointer returned by [`router_new`]. -/// - `context` must be a valid pointer returned by [`context_new`], -/// and must be reset by [`context_reset`] before calling this function -/// if you want to reuse the same context for multiple matches. -#[no_mangle] -pub unsafe extern "C" fn router_execute(router: &Router, context: &mut Context) -> bool { - router.execute(context) -} - -/// Get the de-duplicated fields that are actually used in the router. -/// This is useful when you want to know what fields are actually used in the router, -/// so you can generate their values on-demand. -/// -/// # Arguments -/// -/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. -/// - `fields`: a pointer to an array of pointers to the field names -/// (NOT C-style strings) that are actually used in the router, which will be filled in. -/// if `fields` is `NULL`, this function will only return the number of fields used -/// in the router. -/// - `fields_len`: a pointer to an array of the length of each field name. -/// -/// # Lifetimes -/// -/// The string pointers stored in `fields` might be invalidated if any of the following -/// operations are happened: -/// -/// - The `router` was deallocated. -/// - A new matcher was added to the `router`. -/// - A matcher was removed from the `router`. -/// -/// # Returns -/// -/// Returns the number of fields that are actually used in the router. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `router` must be a valid pointer returned by [`router_new`]. -/// - If `fields` is not `NULL`, `fields` must be valid to read and write for -/// `fields_len * size_of::<*const u8>()` bytes, and it must be properly aligned. -/// - If `fields` is not `NULL`, `fields_len` must be valid to read and write for -/// `size_of::()` bytes, and it must be properly aligned. -/// - DO NOT write the memory pointed by the elements of `fields`. -/// - DO NOT access the memory pointed by the elements of `fields` -/// after it becomes invalid, see the `Lifetimes` section. -#[no_mangle] -pub unsafe extern "C" fn router_get_fields( - router: &Router, - fields: *mut *const u8, - fields_len: *mut usize, -) -> usize { - if !fields.is_null() { - assert!(!fields_len.is_null()); - assert!(*fields_len >= router.fields.len()); - - let fields = from_raw_parts_mut(fields, *fields_len); - let fields_len = from_raw_parts_mut(fields_len, *fields_len); - - for (i, k) in router.fields.keys().enumerate() { - fields[i] = k.as_bytes().as_ptr(); - fields_len[i] = k.len() - } - } - - router.fields.len() -} - -/// Allocate a new context object associated with the schema. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `schema` must be a valid pointer returned by [`schema_new`]. -#[no_mangle] -pub unsafe extern "C" fn context_new(schema: &Schema) -> *mut Context { - Box::into_raw(Box::new(Context::new(schema))) -} - -/// Deallocate the context object. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `context` must be a valid pointer returned by [`context_new`]. -#[no_mangle] -pub unsafe extern "C" fn context_free(context: *mut Context) { - drop(Box::from_raw(context)); -} - -/// Add a value associated with a field to the context. -/// This is useful when you want to match a value against a field in the schema. -/// -/// # Arguments -/// -/// - `context`: a pointer to the [`Context`] object. -/// - `field`: the C-style string representing the field name. -/// - `value`: the value to be added to the context. -/// - `errbuf`: a buffer to store the error message. -/// - `errbuf_len`: a pointer to the length of the error message buffer. -/// -/// # Returns -/// -/// Returns `true` if the value was added successfully, otherwise `false`, -/// and the error message will be stored in the `errbuf`, -/// and the length of the error message will be stored in `errbuf_len`. -/// -/// # Errors -/// -/// This function will return `false` if the value could not be added to the context, -/// such as when a String value is not a valid UTF-8 string. -/// -/// # Panics -/// -/// This function will panic if the provided value does not match the schema. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// * `context` must be a valid pointer returned by [`context_new`]. -/// * `field` must be a valid pointer to a C-style string, -/// must be properply aligned, and must not have '\0' in the middle. -/// * `value` must be a valid pointer to a [`CValue`]. -/// * `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, -/// and it must be properly aligned. -/// * `errbuf_len` must be vlaid to read and write for `size_of::()` bytes, -/// and it must be properly aligned. -#[no_mangle] -pub unsafe extern "C" fn context_add_value( - context: &mut Context, - field: *const i8, - value: &CValue, - errbuf: *mut u8, - errbuf_len: *mut usize, -) -> bool { - let field = ffi::CStr::from_ptr(field as *const c_char) - .to_str() - .unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); - - let value: Result = value.try_into(); - if let Err(e) = value { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; - return false; - } - - context.add_value(field, value.unwrap()); - - true -} - -/// Reset the context so that it can be reused. -/// This is useful when you want to reuse the same context for multiple matches. -/// This will clear all the values that were added to the context, -/// but keep the memory allocated for the context. -/// -/// # Errors -/// -/// This function never fails. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `context` must be a valid pointer returned by [`context_new`]. -#[no_mangle] -pub unsafe extern "C" fn context_reset(context: &mut Context) { - context.reset(); -} - -/// Get the result of the context. -/// -/// # Arguments -/// -/// - `context`: a pointer to the [`Context`] object. -/// - `uuid_hex`: If not `NULL`, the UUID of the matched matcher will be stored. -/// - `matched_field`: If not `NULL`, the field name (C-style string) of the matched value will be stored. -/// - `matched_value`: If the `matched_field` is not `NULL`, the value of the matched field will be stored. -/// - `matched_value_len`: If the `matched_field` is not `NULL`, the length of the value of the matched field will be stored. -/// - `capture_names`: A pointer to an array of pointers to the capture names, each element is a non-C-style string pointer. -/// - `capture_names_len`: A pointer to an array of the length of each capture name. -/// - `capture_values`: A pointer to an array of pointers to the capture values, each element is a non-C-style string pointer. -/// - `capture_values_len`: A pointer to an array of the length of each capture value. -/// -/// # Returns -/// -/// Returns the number of captures that are stored in the context. -/// -/// # Lifetimes -/// -/// The string pointers stored in `matched_value`, `capture_names`, and `capture_values` -/// might be invalidated if any of the following operations are happened: -/// -/// - The `context` was deallocated. -/// - The `context` was reset by [`context_reset`]. -/// -/// # Panics -/// -/// This function will panic if the `matched_field` is not a valid UTF-8 string. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `context` must be a valid pointer returned by [`context_new`], -/// must be passed to [`router_execute`] before calling this function, -/// and must not be reset by [`context_reset`] before calling this function. -/// - If `uuid_hex` is not `NULL`, `uuid_hex` must be valid to read and write for -/// `16 * size_of::()` bytes, and it must be properly aligned. -/// - If `matched_field` is not `NULL`, -/// `matched_field` must be a vlaid pointer to a C-style string, -/// must be properly aligned, and must not have '\0' in the middle. -/// - If `matched_value` is not `NULL`, -/// `matched_value` must be valid to read and write for -/// `mem::size_of::<*const u8>()` bytes, and it must be properly aligned. -/// - If `matched_value` is not `NULL`, `matched_value_len` must be valid to read and write for -/// `size_of::()` bytes, and it must be properly aligned. -/// - If `uuid_hex` is not `NULL`, `capture_names` must be valid to read and write for -/// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. -/// - If `uuid_hex` is not `NULL`, `capture_names_len` must be valid to read and write for -/// ` * size_of::()` bytes, and it must be properly aligned. -/// - If `uuid_hex` is not `NULL`, `capture_values` must be valid to read and write for -/// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. -/// - If `uuid_hex` is not `NULL`, `capture_values_len` must be valid to read and write for -/// ` * size_of::()` bytes, and it must be properly aligned. -/// -/// Note: You should get the `` by calling this function and set every pointer -/// except the `context` to `NULL` to get the number of captures. -#[no_mangle] -pub unsafe extern "C" fn context_get_result( - context: &Context, - uuid_hex: *mut u8, - matched_field: *const i8, - matched_value: *mut *const u8, - matched_value_len: *mut usize, - capture_names: *mut *const u8, - capture_names_len: *mut usize, - capture_values: *mut *const u8, - capture_values_len: *mut usize, -) -> isize { - if context.result.is_none() { - return -1; - } - - if !uuid_hex.is_null() { - let uuid_hex = from_raw_parts_mut(uuid_hex, Hyphenated::LENGTH); - let res = context.result.as_ref().unwrap(); - - res.uuid.as_hyphenated().encode_lower(uuid_hex); - - if !matched_field.is_null() { - let matched_field = ffi::CStr::from_ptr(matched_field as *const c_char) - .to_str() - .unwrap(); - assert!(!matched_value.is_null()); - assert!(!matched_value_len.is_null()); - if let Some(Value::String(v)) = res.matches.get(matched_field) { - *matched_value = v.as_bytes().as_ptr(); - *matched_value_len = v.len(); - } else { - *matched_value_len = 0; - } - } - - if !context.result.as_ref().unwrap().captures.is_empty() { - assert!(*capture_names_len >= res.captures.len()); - assert!(*capture_names_len == *capture_values_len); - assert!(!capture_names.is_null()); - assert!(!capture_names_len.is_null()); - assert!(!capture_values.is_null()); - assert!(!capture_values_len.is_null()); - - let capture_names = from_raw_parts_mut(capture_names, *capture_names_len); - let capture_names_len = from_raw_parts_mut(capture_names_len, *capture_names_len); - let capture_values = from_raw_parts_mut(capture_values, *capture_values_len); - let capture_values_len = from_raw_parts_mut(capture_values_len, *capture_values_len); - - for (i, (k, v)) in res.captures.iter().enumerate() { - capture_names[i] = k.as_bytes().as_ptr(); - capture_names_len[i] = k.len(); - - capture_values[i] = v.as_bytes().as_ptr(); - capture_values_len[i] = v.len(); - } - } - } - - context - .result - .as_ref() - .unwrap() - .captures - .len() - .try_into() - .unwrap() -} - -impl Expression { - fn get_predicates(&self) -> Vec<&Predicate> { - let mut predicates = Vec::new(); - - fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) - where - 'a: 'b, - { - match expr { - Expression::Logical(l) => match l.as_ref() { - LogicalExpression::And(l, r) => { - visit(l, predicates); - visit(r, predicates); - } - LogicalExpression::Or(l, r) => { - visit(l, predicates); - visit(r, predicates); - } - LogicalExpression::Not(r) => { - visit(r, predicates); - } - }, - Expression::Predicate(p) => { - predicates.push(p); - } - } - } - - visit(self, &mut predicates); - - predicates - } -} - -bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] - #[repr(C)] - pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { - const EQUALS = 1 << 0; - const NOT_EQUALS = 1 << 1; - const REGEX = 1 << 2; - const PREFIX = 1 << 3; - const POSTFIX = 1 << 4; - const GREATER = 1 << 5; - const GREATER_OR_EQUAL = 1 << 6; - const LESS = 1 << 7; - const LESS_OR_EQUAL = 1 << 8; - const IN = 1 << 9; - const NOT_IN = 1 << 10; - const CONTAINS = 1 << 11; - - const UNUSED = !(Self::EQUALS.bits() - | Self::NOT_EQUALS.bits() - | Self::REGEX.bits() - | Self::PREFIX.bits() - | Self::POSTFIX.bits() - | Self::GREATER.bits() - | Self::GREATER_OR_EQUAL.bits() - | Self::LESS.bits() - | Self::LESS_OR_EQUAL.bits() - | Self::IN.bits() - | Self::NOT_IN.bits() - | Self::CONTAINS.bits()); - } -} - -impl From<&BinaryOperator> for BinaryOperatorFlags { - fn from(op: &BinaryOperator) -> Self { - match op { - BinaryOperator::Equals => Self::EQUALS, - BinaryOperator::NotEquals => Self::NOT_EQUALS, - BinaryOperator::Regex => Self::REGEX, - BinaryOperator::Prefix => Self::PREFIX, - BinaryOperator::Postfix => Self::POSTFIX, - BinaryOperator::Greater => Self::GREATER, - BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, - BinaryOperator::Less => Self::LESS, - BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, - BinaryOperator::In => Self::IN, - BinaryOperator::NotIn => Self::NOT_IN, - BinaryOperator::Contains => Self::CONTAINS, - } - } -} - -pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; -pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; -pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; - -/// Validate the ATC expression with the schema. -/// -/// # Arguments -/// -/// - `atc`: the C-style string representing the ATC expression. -/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. -/// - `fields_buf`: a buffer to store the used fields. -/// - `fields_len`: a pointer to the length of the fields buffer. -/// - `fields_total`: a pointer for saving the total number of the fields. -/// - `operators`: a pointer for saving the used operators with bitflags. -/// - `errbuf`: a buffer to store the error message. -/// - `errbuf_len`: a pointer to the length of the error message buffer. -/// -/// # Returns -/// -/// Returns an integer value indicating the validation result: -/// - ATC_ROUTER_EXPRESSION_VALIDATE_OK(0) if validation is passed. -/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. The `errbuf` and `errbuf_len` will be updated with the error message. -/// - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. -/// -/// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, -/// it will write the used fields to the buffer, each terminated by '\0' and the total number of fields -/// to the `fields_total`, and `fields_len` will be updated with the total buffer length. -/// -/// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, -/// it will write the required buffer length to the `fields_len`, and the total number of fields -/// to the `fields_total`, and return `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. -/// -/// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. -/// The bitflags is defined by `BinaryOperatorFlags` and it must not contain any bits from `BinaryOperatorFlags::UNUSED`. -/// -/// -/// # Panics -/// -/// This function will panic when: -/// -/// - `atc` doesn't point to a valid C-style string. -/// - `fields_len` and `fields_total` are null when `fields_buf` is non-null. -/// -/// # Safety -/// -/// Violating any of the following constraints will result in undefined behavior: -/// -/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. -/// - `schema` must be a valid pointer returned by [`schema_new`]. -/// - `fields_buf` must be a valid to write for `fields_len * size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `fields_len` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `fields_total` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `operators` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, -/// and it must be properly aligned. -/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, -/// and it must be properly aligned. -/// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. -/// for writing required buffer length and total number of fields. -#[no_mangle] -pub unsafe extern "C" fn expression_validate( - atc: *const u8, - schema: &Schema, - fields_buf: *mut u8, - fields_len: *mut usize, - fields_total: *mut usize, - operators: *mut u64, - errbuf: *mut u8, - errbuf_len: *mut usize, -) -> i64 { - use std::collections::HashSet; - - use crate::parser::parse; - use crate::semantics::Validate; - - let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); - - // Parse the expression - let result = parse(atc).map_err(|e| e.to_string()); - if let Err(e) = result { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; - return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; - } - // Unwrap is safe since we've already checked for error - let ast = result.unwrap(); - - // Validate expression with schema - if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; - return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; - } - - // Direct use GetPredicates trait to avoid unnecessary accesses - let predicates = ast.get_predicates(); - - // Get used fields - if !fields_buf.is_null() { - assert!( - !(fields_len.is_null() || fields_total.is_null()), - "fields_len and fields_total must be non-null when fields_buf is non-null" - ); - - let expr_fields = predicates - .iter() - .map(|p| p.lhs.var_name.as_str()) - .collect::>(); - let total_fields_length = expr_fields - .iter() - .map(|k| k.as_bytes().len() + 1) - .sum::(); - - if !fields_buf.is_null() { - if *fields_len < total_fields_length { - *fields_len = total_fields_length; - *fields_total = expr_fields.len(); - return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; - } - - let mut fields_buf_ptr = fields_buf; - for field in &expr_fields { - let field = ffi::CString::new(*field).unwrap(); - let field_slice = field.as_bytes_with_nul(); - let field_len = field_slice.len(); - let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); - fields_buf.copy_from_slice(field_slice); - fields_buf_ptr = fields_buf_ptr.add(field_len); - } - } - - *fields_len = total_fields_length; - *fields_total = expr_fields.len(); - } - - // Get used operators - if !operators.is_null() { - let mut ops = BinaryOperatorFlags::empty(); - for pred in &predicates { - ops |= BinaryOperatorFlags::from(&pred.op); - } - *operators = ops.bits(); - } - - ATC_ROUTER_EXPRESSION_VALIDATE_OK -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_long_error_message() { - unsafe { - let schema = Schema::default(); - let mut router = Router::new(&schema); - let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); - let junk = ffi::CString::new(vec![b'a'; ERR_BUF_MAX_LEN * 2]).unwrap(); - let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; - let mut errbuf_len = ERR_BUF_MAX_LEN; - - let result = router_add_matcher( - &mut router, - 1, - uuid.as_ptr() as *const i8, - junk.as_ptr() as *const i8, - errbuf.as_mut_ptr(), - &mut errbuf_len, - ); - assert_eq!(result, false); - assert_eq!(errbuf_len, ERR_BUF_MAX_LEN); - } - } - - #[test] - fn test_short_error_message() { - unsafe { - let schema = Schema::default(); - let mut router = Router::new(&schema); - let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); - let junk = ffi::CString::new("aaaa").unwrap(); - let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; - let mut errbuf_len = ERR_BUF_MAX_LEN; - - let result = router_add_matcher( - &mut router, - 1, - uuid.as_ptr() as *const i8, - junk.as_ptr() as *const i8, - errbuf.as_mut_ptr(), - &mut errbuf_len, - ); - assert_eq!(result, false); - assert!(errbuf_len < ERR_BUF_MAX_LEN); - } - } - - fn expr_validate_on( - schema: &Schema, - atc: &str, - fields_buf_size: usize, - ) -> Result<(Vec, u64), (i64, String)> { - let atc = ffi::CString::new(atc).unwrap(); - let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; - let mut errbuf_len = ERR_BUF_MAX_LEN; - - let mut fields_buf = vec![0u8; fields_buf_size]; - let mut fields_len = fields_buf.len(); - let mut fields_total = 0; - let mut operators = 0u64; - - let result = unsafe { - expression_validate( - atc.as_bytes().as_ptr(), - &schema, - fields_buf.as_mut_ptr(), - &mut fields_len, - &mut fields_total, - &mut operators, - errbuf.as_mut_ptr(), - &mut errbuf_len, - ) - }; - - match result { - ATC_ROUTER_EXPRESSION_VALIDATE_OK => { - let mut fields = Vec::::with_capacity(fields_total); - let mut p = 0; - for _ in 0..fields_total { - let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; - let len = field.to_bytes().len() + 1; - fields.push(field.to_string_lossy().to_string()); - p += len; - } - assert_eq!(fields_len, p, "Fields buffer length mismatch"); - fields.sort(); - Ok((fields, operators)) - } - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { - let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); - Err((result, err)) - } - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), - _ => panic!("Unknown error code"), - } - } - - #[test] - fn test_expression_validate_success() { - let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; - - let mut schema = Schema::default(); - schema.add_field("net.protocol", Type::String); - schema.add_field("net.dst.port", Type::Int); - schema.add_field("net.src.ip", Type::IpAddr); - schema.add_field("http.path", Type::String); - - let result = expr_validate_on(&schema, atc, 1024); - - assert!(result.is_ok(), "Validation failed"); - let (fields, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it - assert_eq!( - ops, - (BinaryOperatorFlags::EQUALS - | BinaryOperatorFlags::REGEX - | BinaryOperatorFlags::IN - | BinaryOperatorFlags::NOT_IN - | BinaryOperatorFlags::CONTAINS) - .bits(), - "Operators mismatch" - ); - assert_eq!( - fields, - vec![ - "http.path".to_string(), - "net.dst.port".to_string(), - "net.protocol".to_string(), - "net.src.ip".to_string() - ], - "Fields mismatch" - ); - } - - #[test] - fn test_expression_validate_failed_parse() { - let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##; - - let mut schema = Schema::default(); - schema.add_field("net.protocol", Type::String); - schema.add_field("net.dst.port", Type::Int); - schema.add_field("net.src.ip", Type::IpAddr); - schema.add_field("http.path", Type::String); - - let result = expr_validate_on(&schema, atc, 1024); - - assert!(result.is_err(), "Validation unexcepted success"); - let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it - assert_eq!( - err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, - "Error code mismatch" - ); - assert_eq!( - err_message, - "In/NotIn operators only supports IP in CIDR".to_string(), - "Error message mismatch" - ); - } - - #[test] - fn test_expression_validate_failed_validate() { - let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; - - let mut schema = Schema::default(); - schema.add_field("net.protocol", Type::String); - schema.add_field("net.dst.port", Type::Int); - schema.add_field("net.src.ip", Type::IpAddr); - - let result = expr_validate_on(&schema, atc, 1024); - - assert!(result.is_err(), "Validation unexcepted success"); - let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it - assert_eq!( - err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, - "Error code mismatch" - ); - assert_eq!( - err_message, - "Unknown LHS field".to_string(), - "Error message mismatch" - ); - } - - #[test] - fn test_expression_validate_buf_too_small() { - let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; - - let mut schema = Schema::default(); - schema.add_field("net.protocol", Type::String); - schema.add_field("net.dst.port", Type::Int); - schema.add_field("net.src.ip", Type::IpAddr); - schema.add_field("http.path", Type::String); - - let result = expr_validate_on(&schema, atc, 10); - - assert!(result.is_err(), "Validation failed"); - let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it - assert_eq!( - err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL, - "Error code mismatch" - ); - } -} diff --git a/src/ffi/context.rs b/src/ffi/context.rs new file mode 100644 index 00000000..bbeab135 --- /dev/null +++ b/src/ffi/context.rs @@ -0,0 +1,251 @@ +use crate::ast::Value; +use crate::context::Context; +use crate::ffi::{CValue, ERR_BUF_MAX_LEN}; +use crate::schema::Schema; +use std::cmp::min; +use std::ffi; +use std::os::raw::c_char; +use std::slice::from_raw_parts_mut; +use uuid::fmt::Hyphenated; + +/// Allocate a new context object associated with the schema. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `schema` must be a valid pointer returned by [`schema_new`]. +#[no_mangle] +pub unsafe extern "C" fn context_new(schema: &Schema) -> *mut Context { + Box::into_raw(Box::new(Context::new(schema))) +} + +/// Deallocate the context object. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `context` must be a valid pointer returned by [`context_new`]. +#[no_mangle] +pub unsafe extern "C" fn context_free(context: *mut Context) { + drop(Box::from_raw(context)); +} + +/// Add a value associated with a field to the context. +/// This is useful when you want to match a value against a field in the schema. +/// +/// # Arguments +/// +/// - `context`: a pointer to the [`Context`] object. +/// - `field`: the C-style string representing the field name. +/// - `value`: the value to be added to the context. +/// - `errbuf`: a buffer to store the error message. +/// - `errbuf_len`: a pointer to the length of the error message buffer. +/// +/// # Returns +/// +/// Returns `true` if the value was added successfully, otherwise `false`, +/// and the error message will be stored in the `errbuf`, +/// and the length of the error message will be stored in `errbuf_len`. +/// +/// # Errors +/// +/// This function will return `false` if the value could not be added to the context, +/// such as when a String value is not a valid UTF-8 string. +/// +/// # Panics +/// +/// This function will panic if the provided value does not match the schema. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// * `context` must be a valid pointer returned by [`context_new`]. +/// * `field` must be a valid pointer to a C-style string, +/// must be properply aligned, and must not have '\0' in the middle. +/// * `value` must be a valid pointer to a [`CValue`]. +/// * `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, +/// and it must be properly aligned. +/// * `errbuf_len` must be vlaid to read and write for `size_of::()` bytes, +/// and it must be properly aligned. +#[no_mangle] +pub unsafe extern "C" fn context_add_value( + context: &mut Context, + field: *const i8, + value: &CValue, + errbuf: *mut u8, + errbuf_len: *mut usize, +) -> bool { + let field = ffi::CStr::from_ptr(field as *const c_char) + .to_str() + .unwrap(); + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + + let value: Result = value.try_into(); + if let Err(e) = value { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + return false; + } + + context.add_value(field, value.unwrap()); + + true +} + +/// Reset the context so that it can be reused. +/// This is useful when you want to reuse the same context for multiple matches. +/// This will clear all the values that were added to the context, +/// but keep the memory allocated for the context. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `context` must be a valid pointer returned by [`context_new`]. +#[no_mangle] +pub unsafe extern "C" fn context_reset(context: &mut Context) { + context.reset(); +} + +/// Get the result of the context. +/// +/// # Arguments +/// +/// - `context`: a pointer to the [`Context`] object. +/// - `uuid_hex`: If not `NULL`, the UUID of the matched matcher will be stored. +/// - `matched_field`: If not `NULL`, the field name (C-style string) of the matched value will be stored. +/// - `matched_value`: If the `matched_field` is not `NULL`, the value of the matched field will be stored. +/// - `matched_value_len`: If the `matched_field` is not `NULL`, the length of the value of the matched field will be stored. +/// - `capture_names`: A pointer to an array of pointers to the capture names, each element is a non-C-style string pointer. +/// - `capture_names_len`: A pointer to an array of the length of each capture name. +/// - `capture_values`: A pointer to an array of pointers to the capture values, each element is a non-C-style string pointer. +/// - `capture_values_len`: A pointer to an array of the length of each capture value. +/// +/// # Returns +/// +/// Returns the number of captures that are stored in the context. +/// +/// # Lifetimes +/// +/// The string pointers stored in `matched_value`, `capture_names`, and `capture_values` +/// might be invalidated if any of the following operations are happened: +/// +/// - The `context` was deallocated. +/// - The `context` was reset by [`context_reset`]. +/// +/// # Panics +/// +/// This function will panic if the `matched_field` is not a valid UTF-8 string. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `context` must be a valid pointer returned by [`context_new`], +/// must be passed to [`router_execute`] before calling this function, +/// and must not be reset by [`context_reset`] before calling this function. +/// - If `uuid_hex` is not `NULL`, `uuid_hex` must be valid to read and write for +/// `16 * size_of::()` bytes, and it must be properly aligned. +/// - If `matched_field` is not `NULL`, +/// `matched_field` must be a vlaid pointer to a C-style string, +/// must be properly aligned, and must not have '\0' in the middle. +/// - If `matched_value` is not `NULL`, +/// `matched_value` must be valid to read and write for +/// `mem::size_of::<*const u8>()` bytes, and it must be properly aligned. +/// - If `matched_value` is not `NULL`, `matched_value_len` must be valid to read and write for +/// `size_of::()` bytes, and it must be properly aligned. +/// - If `uuid_hex` is not `NULL`, `capture_names` must be valid to read and write for +/// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. +/// - If `uuid_hex` is not `NULL`, `capture_names_len` must be valid to read and write for +/// ` * size_of::()` bytes, and it must be properly aligned. +/// - If `uuid_hex` is not `NULL`, `capture_values` must be valid to read and write for +/// ` * size_of::<*const u8>()` bytes, and it must be properly aligned. +/// - If `uuid_hex` is not `NULL`, `capture_values_len` must be valid to read and write for +/// ` * size_of::()` bytes, and it must be properly aligned. +/// +/// Note: You should get the `` by calling this function and set every pointer +/// except the `context` to `NULL` to get the number of captures. +#[no_mangle] +pub unsafe extern "C" fn context_get_result( + context: &Context, + uuid_hex: *mut u8, + matched_field: *const i8, + matched_value: *mut *const u8, + matched_value_len: *mut usize, + capture_names: *mut *const u8, + capture_names_len: *mut usize, + capture_values: *mut *const u8, + capture_values_len: *mut usize, +) -> isize { + if context.result.is_none() { + return -1; + } + + if !uuid_hex.is_null() { + let uuid_hex = from_raw_parts_mut(uuid_hex, Hyphenated::LENGTH); + let res = context.result.as_ref().unwrap(); + + res.uuid.as_hyphenated().encode_lower(uuid_hex); + + if !matched_field.is_null() { + let matched_field = ffi::CStr::from_ptr(matched_field as *const c_char) + .to_str() + .unwrap(); + assert!(!matched_value.is_null()); + assert!(!matched_value_len.is_null()); + if let Some(Value::String(v)) = res.matches.get(matched_field) { + *matched_value = v.as_bytes().as_ptr(); + *matched_value_len = v.len(); + } else { + *matched_value_len = 0; + } + } + + if !context.result.as_ref().unwrap().captures.is_empty() { + assert!(*capture_names_len >= res.captures.len()); + assert!(*capture_names_len == *capture_values_len); + assert!(!capture_names.is_null()); + assert!(!capture_names_len.is_null()); + assert!(!capture_values.is_null()); + assert!(!capture_values_len.is_null()); + + let capture_names = from_raw_parts_mut(capture_names, *capture_names_len); + let capture_names_len = from_raw_parts_mut(capture_names_len, *capture_names_len); + let capture_values = from_raw_parts_mut(capture_values, *capture_values_len); + let capture_values_len = from_raw_parts_mut(capture_values_len, *capture_values_len); + + for (i, (k, v)) in res.captures.iter().enumerate() { + capture_names[i] = k.as_bytes().as_ptr(); + capture_names_len[i] = k.len(); + + capture_values[i] = v.as_bytes().as_ptr(); + capture_values_len[i] = v.len(); + } + } + } + + context + .result + .as_ref() + .unwrap() + .captures + .len() + .try_into() + .unwrap() +} diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs new file mode 100644 index 00000000..0cedd4fc --- /dev/null +++ b/src/ffi/expression.rs @@ -0,0 +1,409 @@ +use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate}; +use crate::ffi::ERR_BUF_MAX_LEN; +use crate::schema::Schema; +use bitflags::bitflags; +use std::cmp::min; +use std::ffi; +use std::os::raw::c_char; +use std::slice::from_raw_parts_mut; + +impl Expression { + fn get_predicates(&self) -> Vec<&Predicate> { + let mut predicates = Vec::new(); + + fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) + where + 'a: 'b, + { + match expr { + Expression::Logical(l) => match l.as_ref() { + LogicalExpression::And(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Or(l, r) => { + visit(l, predicates); + visit(r, predicates); + } + LogicalExpression::Not(r) => { + visit(r, predicates); + } + }, + Expression::Predicate(p) => { + predicates.push(p); + } + } + } + + visit(self, &mut predicates); + + predicates + } +} + +bitflags! { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[repr(C)] + pub struct BinaryOperatorFlags: u64 /* We can only have no more than 64 BinaryOperators */ { + const EQUALS = 1 << 0; + const NOT_EQUALS = 1 << 1; + const REGEX = 1 << 2; + const PREFIX = 1 << 3; + const POSTFIX = 1 << 4; + const GREATER = 1 << 5; + const GREATER_OR_EQUAL = 1 << 6; + const LESS = 1 << 7; + const LESS_OR_EQUAL = 1 << 8; + const IN = 1 << 9; + const NOT_IN = 1 << 10; + const CONTAINS = 1 << 11; + + const UNUSED = !(Self::EQUALS.bits() + | Self::NOT_EQUALS.bits() + | Self::REGEX.bits() + | Self::PREFIX.bits() + | Self::POSTFIX.bits() + | Self::GREATER.bits() + | Self::GREATER_OR_EQUAL.bits() + | Self::LESS.bits() + | Self::LESS_OR_EQUAL.bits() + | Self::IN.bits() + | Self::NOT_IN.bits() + | Self::CONTAINS.bits()); + } +} + +impl From<&BinaryOperator> for BinaryOperatorFlags { + fn from(op: &BinaryOperator) -> Self { + match op { + BinaryOperator::Equals => Self::EQUALS, + BinaryOperator::NotEquals => Self::NOT_EQUALS, + BinaryOperator::Regex => Self::REGEX, + BinaryOperator::Prefix => Self::PREFIX, + BinaryOperator::Postfix => Self::POSTFIX, + BinaryOperator::Greater => Self::GREATER, + BinaryOperator::GreaterOrEqual => Self::GREATER_OR_EQUAL, + BinaryOperator::Less => Self::LESS, + BinaryOperator::LessOrEqual => Self::LESS_OR_EQUAL, + BinaryOperator::In => Self::IN, + BinaryOperator::NotIn => Self::NOT_IN, + BinaryOperator::Contains => Self::CONTAINS, + } + } +} + +pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; +pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; +pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; + +/// Validate the ATC expression with the schema. +/// +/// # Arguments +/// +/// - `atc`: the C-style string representing the ATC expression. +/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. +/// - `fields_buf`: a buffer to store the used fields. +/// - `fields_len`: a pointer to the length of the fields buffer. +/// - `fields_total`: a pointer for saving the total number of the fields. +/// - `operators`: a pointer for saving the used operators with bitflags. +/// - `errbuf`: a buffer to store the error message. +/// - `errbuf_len`: a pointer to the length of the error message buffer. +/// +/// # Returns +/// +/// Returns an integer value indicating the validation result: +/// - ATC_ROUTER_EXPRESSION_VALIDATE_OK(0) if validation is passed. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. The `errbuf` and `errbuf_len` will be updated with the error message. +/// - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. +/// +/// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, +/// it will write the used fields to the buffer, each terminated by '\0' and the total number of fields +/// to the `fields_total`, and `fields_len` will be updated with the total buffer length. +/// +/// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, +/// it will write the required buffer length to the `fields_len`, and the total number of fields +/// to the `fields_total`, and return `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. +/// +/// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. +/// The bitflags is defined by `BinaryOperatorFlags` and it must not contain any bits from `BinaryOperatorFlags::UNUSED`. +/// +/// +/// # Panics +/// +/// This function will panic when: +/// +/// - `atc` doesn't point to a valid C-style string. +/// - `fields_len` and `fields_total` are null when `fields_buf` is non-null. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +/// - `schema` must be a valid pointer returned by [`schema_new`]. +/// - `fields_buf` must be a valid to write for `fields_len * size_of::()` bytes, +/// and it must be properly aligned if non-null. +/// - `fields_len` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned if non-null. +/// - `fields_total` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned if non-null. +/// - `operators` must be a valid to write for `size_of::()` bytes, +/// and it must be properly aligned if non-null. +/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, +/// and it must be properly aligned. +/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, +/// and it must be properly aligned. +/// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. +/// for writing required buffer length and total number of fields. +#[no_mangle] +pub unsafe extern "C" fn expression_validate( + atc: *const u8, + schema: &Schema, + fields_buf: *mut u8, + fields_len: *mut usize, + fields_total: *mut usize, + operators: *mut u64, + errbuf: *mut u8, + errbuf_len: *mut usize, +) -> i64 { + use std::collections::HashSet; + + use crate::parser::parse; + use crate::semantics::Validate; + + let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + + // Parse the expression + let result = parse(atc).map_err(|e| e.to_string()); + if let Err(e) = result { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; + } + // Unwrap is safe since we've already checked for error + let ast = result.unwrap(); + + // Validate expression with schema + if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; + } + + // Direct use GetPredicates trait to avoid unnecessary accesses + let predicates = ast.get_predicates(); + + // Get used fields + if !fields_buf.is_null() { + assert!( + !(fields_len.is_null() || fields_total.is_null()), + "fields_len and fields_total must be non-null when fields_buf is non-null" + ); + + let expr_fields = predicates + .iter() + .map(|p| p.lhs.var_name.as_str()) + .collect::>(); + let total_fields_length = expr_fields + .iter() + .map(|k| k.as_bytes().len() + 1) + .sum::(); + + if !fields_buf.is_null() { + if *fields_len < total_fields_length { + *fields_len = total_fields_length; + *fields_total = expr_fields.len(); + return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; + } + + let mut fields_buf_ptr = fields_buf; + for field in &expr_fields { + let field = ffi::CString::new(*field).unwrap(); + let field_slice = field.as_bytes_with_nul(); + let field_len = field_slice.len(); + let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); + fields_buf.copy_from_slice(field_slice); + fields_buf_ptr = fields_buf_ptr.add(field_len); + } + } + + *fields_len = total_fields_length; + *fields_total = expr_fields.len(); + } + + // Get used operators + if !operators.is_null() { + let mut ops = BinaryOperatorFlags::empty(); + for pred in &predicates { + ops |= BinaryOperatorFlags::from(&pred.op); + } + *operators = ops.bits(); + } + + ATC_ROUTER_EXPRESSION_VALIDATE_OK +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Type; + + fn expr_validate_on( + schema: &Schema, + atc: &str, + fields_buf_size: usize, + ) -> Result<(Vec, u64), (i64, String)> { + let atc = ffi::CString::new(atc).unwrap(); + let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; + let mut errbuf_len = ERR_BUF_MAX_LEN; + + let mut fields_buf = vec![0u8; fields_buf_size]; + let mut fields_len = fields_buf.len(); + let mut fields_total = 0; + let mut operators = 0u64; + + let result = unsafe { + expression_validate( + atc.as_bytes().as_ptr(), + &schema, + fields_buf.as_mut_ptr(), + &mut fields_len, + &mut fields_total, + &mut operators, + errbuf.as_mut_ptr(), + &mut errbuf_len, + ) + }; + + match result { + ATC_ROUTER_EXPRESSION_VALIDATE_OK => { + let mut fields = Vec::::with_capacity(fields_total); + let mut p = 0; + for _ in 0..fields_total { + let field = unsafe { ffi::CStr::from_ptr(fields_buf[p..].as_ptr().cast()) }; + let len = field.to_bytes().len() + 1; + fields.push(field.to_string_lossy().to_string()); + p += len; + } + assert_eq!(fields_len, p, "Fields buffer length mismatch"); + fields.sort(); + Ok((fields, operators)) + } + ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { + let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); + Err((result, err)) + } + ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), + _ => panic!("Unknown error code"), + } + } + + #[test] + fn test_expression_validate_success() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_ok(), "Validation failed"); + let (fields, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it + assert_eq!( + ops, + (BinaryOperatorFlags::EQUALS + | BinaryOperatorFlags::REGEX + | BinaryOperatorFlags::IN + | BinaryOperatorFlags::NOT_IN + | BinaryOperatorFlags::CONTAINS) + .bits(), + "Operators mismatch" + ); + assert_eq!( + fields, + vec![ + "http.path".to_string(), + "net.dst.port".to_string(), + "net.protocol".to_string(), + "net.src.ip".to_string() + ], + "Fields mismatch" + ); + } + + #[test] + fn test_expression_validate_failed_parse() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_err(), "Validation unexcepted success"); + let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, + "Error code mismatch" + ); + assert_eq!( + err_message, + "In/NotIn operators only supports IP in CIDR".to_string(), + "Error message mismatch" + ); + } + + #[test] + fn test_expression_validate_failed_validate() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + + let result = expr_validate_on(&schema, atc, 1024); + + assert!(result.is_err(), "Validation unexcepted success"); + let (err_code, err_message) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_FAILED, + "Error code mismatch" + ); + assert_eq!( + err_message, + "Unknown LHS field".to_string(), + "Error message mismatch" + ); + } + + #[test] + fn test_expression_validate_buf_too_small() { + let atc = r##"net.protocol ~ "^https?$" && net.dst.port == 80 && (net.src.ip not in 10.0.0.0/16 || net.src.ip in 10.0.1.0/24) && http.path contains "hello""##; + + let mut schema = Schema::default(); + schema.add_field("net.protocol", Type::String); + schema.add_field("net.dst.port", Type::Int); + schema.add_field("net.src.ip", Type::IpAddr); + schema.add_field("http.path", Type::String); + + let result = expr_validate_on(&schema, atc, 10); + + assert!(result.is_err(), "Validation failed"); + let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it + assert_eq!( + err_code, ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL, + "Error code mismatch" + ); + } +} diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs new file mode 100644 index 00000000..c7d01b1a --- /dev/null +++ b/src/ffi/mod.rs @@ -0,0 +1,58 @@ +pub mod context; +pub mod expression; +pub mod router; +pub mod schema; + +use crate::ast::Value; +use cidr::IpCidr; +use std::convert::TryFrom; +use std::ffi; +use std::net::IpAddr; +use std::os::raw::c_char; +use std::slice::from_raw_parts; + +pub const ERR_BUF_MAX_LEN: usize = 4096; + +#[derive(Debug)] +#[repr(C)] +pub enum CValue { + Str(*const u8, usize), + IpCidr(*const u8), + IpAddr(*const u8), + Int(i64), +} + +impl TryFrom<&CValue> for Value { + type Error = String; + + fn try_from(v: &CValue) -> Result { + Ok(match v { + CValue::Str(s, len) => Self::String(unsafe { + std::str::from_utf8(from_raw_parts(*s, *len)) + .map_err(|e| e.to_string())? + .to_string() + }), + CValue::IpCidr(s) => Self::IpCidr( + unsafe { + ffi::CStr::from_ptr(*s as *const c_char) + .to_str() + .map_err(|e| e.to_string())? + .to_string() + } + .parse::() + .map_err(|e| e.to_string())?, + ), + CValue::IpAddr(s) => Self::IpAddr( + unsafe { + ffi::CStr::from_ptr(*s as *const c_char) + .to_str() + .map_err(|e| e.to_string())? + .to_string() + } + .parse::() + .map_err(|e| e.to_string())?, + ), + CValue::Int(i) => Self::Int(*i), + }) + } +} diff --git a/src/ffi/router.rs b/src/ffi/router.rs new file mode 100644 index 00000000..c7e9f65c --- /dev/null +++ b/src/ffi/router.rs @@ -0,0 +1,288 @@ +use crate::context::Context; +use crate::ffi::ERR_BUF_MAX_LEN; +use crate::router::Router; +use crate::schema::Schema; +use std::cmp::min; +use std::ffi; +use std::os::raw::c_char; +use std::slice::from_raw_parts_mut; +use uuid::Uuid; + +/// Create a new router object associated with the schema. +/// +/// # Arguments +/// +/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `schema` must be a valid pointer returned by [`schema_new`]. +#[no_mangle] +pub unsafe extern "C" fn router_new(schema: &Schema) -> *mut Router { + Box::into_raw(Box::new(Router::new(schema))) +} + +/// Deallocate the router object. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `router` must be a valid pointer returned by [`router_new`]. +#[no_mangle] +pub unsafe extern "C" fn router_free(router: *mut Router) { + drop(Box::from_raw(router)); +} + +/// Add a new matcher to the router. +/// +/// # Arguments +/// +/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. +/// - `priority`: the priority of the matcher, higher value means higher priority, +/// and the matcher with the highest priority will be executed first. +/// - `uuid`: the C-style string representing the UUID of the matcher. +/// - `atc`: the C-style string representing the ATC expression. +/// - `errbuf`: a buffer to store the error message. +/// - `errbuf_len`: a pointer to the length of the error message buffer. +/// +/// # Returns +/// +/// Returns `true` if the matcher was added successfully, otherwise `false`, +/// and the error message will be stored in the `errbuf`, +/// and the length of the error message will be stored in `errbuf_len`. +/// +/// # Errors +/// +/// This function will return `false` if the matcher could not be added to the router, +/// such as duplicate UUID, and invalid ATC expression. +/// +/// # Panics +/// +/// This function will panic when: +/// +/// - `uuid` doesn't point to a ASCII sequence representing a valid 128-bit UUID. +/// - `atc` doesn't point to a valid C-style string. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `router` must be a valid pointer returned by [`router_new`]. +/// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, +/// and it must be properly aligned. +/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, +/// and it must be properly aligned. +#[no_mangle] +pub unsafe extern "C" fn router_add_matcher( + router: &mut Router, + priority: usize, + uuid: *const i8, + atc: *const i8, + errbuf: *mut u8, + errbuf_len: *mut usize, +) -> bool { + let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); + let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + + let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); + + if let Err(e) = router.add_matcher(priority, uuid, atc) { + let errlen = min(e.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); + *errbuf_len = errlen; + return false; + } + + true +} + +/// Remove a matcher from the router. +/// +/// # Arguments +/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. +/// - `priority`: the priority of the matcher to be removed. +/// - `uuid`: the C-style string representing the UUID of the matcher to be removed. +/// +/// # Returns +/// +/// Returns `true` if the matcher was removed successfully, otherwise `false`, +/// such as when the matcher with the specified UUID doesn't exist or +/// the priority doesn't match the UUID. +/// +/// # Panics +/// +/// This function will panic when `uuid` doesn't point to a ASCII sequence +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `router` must be a valid pointer returned by [`router_new`]. +/// - `uuid` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +#[no_mangle] +pub unsafe extern "C" fn router_remove_matcher( + router: &mut Router, + priority: usize, + uuid: *const i8, +) -> bool { + let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); + let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); + + router.remove_matcher(priority, uuid) +} + +/// Execute the router with the context. +/// +/// # Arguments +/// +/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. +/// - `context`: a pointer to the [`Context`] object. +/// +/// # Returns +/// +/// Returns `true` if found a match, `false` means no match found. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `router` must be a valid pointer returned by [`router_new`]. +/// - `context` must be a valid pointer returned by [`context_new`], +/// and must be reset by [`context_reset`] before calling this function +/// if you want to reuse the same context for multiple matches. +#[no_mangle] +pub unsafe extern "C" fn router_execute(router: &Router, context: &mut Context) -> bool { + router.execute(context) +} + +/// Get the de-duplicated fields that are actually used in the router. +/// This is useful when you want to know what fields are actually used in the router, +/// so you can generate their values on-demand. +/// +/// # Arguments +/// +/// - `router`: a pointer to the [`Router`] object returned by [`router_new`]. +/// - `fields`: a pointer to an array of pointers to the field names +/// (NOT C-style strings) that are actually used in the router, which will be filled in. +/// if `fields` is `NULL`, this function will only return the number of fields used +/// in the router. +/// - `fields_len`: a pointer to an array of the length of each field name. +/// +/// # Lifetimes +/// +/// The string pointers stored in `fields` might be invalidated if any of the following +/// operations are happened: +/// +/// - The `router` was deallocated. +/// - A new matcher was added to the `router`. +/// - A matcher was removed from the `router`. +/// +/// # Returns +/// +/// Returns the number of fields that are actually used in the router. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `router` must be a valid pointer returned by [`router_new`]. +/// - If `fields` is not `NULL`, `fields` must be valid to read and write for +/// `fields_len * size_of::<*const u8>()` bytes, and it must be properly aligned. +/// - If `fields` is not `NULL`, `fields_len` must be valid to read and write for +/// `size_of::()` bytes, and it must be properly aligned. +/// - DO NOT write the memory pointed by the elements of `fields`. +/// - DO NOT access the memory pointed by the elements of `fields` +/// after it becomes invalid, see the `Lifetimes` section. +#[no_mangle] +pub unsafe extern "C" fn router_get_fields( + router: &Router, + fields: *mut *const u8, + fields_len: *mut usize, +) -> usize { + if !fields.is_null() { + assert!(!fields_len.is_null()); + assert!(*fields_len >= router.fields.len()); + + let fields = from_raw_parts_mut(fields, *fields_len); + let fields_len = from_raw_parts_mut(fields_len, *fields_len); + + for (i, k) in router.fields.keys().enumerate() { + fields[i] = k.as_bytes().as_ptr(); + fields_len[i] = k.len() + } + } + + router.fields.len() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_long_error_message() { + unsafe { + let schema = Schema::default(); + let mut router = Router::new(&schema); + let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); + let junk = ffi::CString::new(vec![b'a'; ERR_BUF_MAX_LEN * 2]).unwrap(); + let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; + let mut errbuf_len = ERR_BUF_MAX_LEN; + + let result = router_add_matcher( + &mut router, + 1, + uuid.as_ptr() as *const i8, + junk.as_ptr() as *const i8, + errbuf.as_mut_ptr(), + &mut errbuf_len, + ); + assert_eq!(result, false); + assert_eq!(errbuf_len, ERR_BUF_MAX_LEN); + } + } + + #[test] + fn test_short_error_message() { + unsafe { + let schema = Schema::default(); + let mut router = Router::new(&schema); + let uuid = ffi::CString::new("a921a9aa-ec0e-4cf3-a6cc-1aa5583d150c").unwrap(); + let junk = ffi::CString::new("aaaa").unwrap(); + let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; + let mut errbuf_len = ERR_BUF_MAX_LEN; + + let result = router_add_matcher( + &mut router, + 1, + uuid.as_ptr() as *const i8, + junk.as_ptr() as *const i8, + errbuf.as_mut_ptr(), + &mut errbuf_len, + ); + assert_eq!(result, false); + assert!(errbuf_len < ERR_BUF_MAX_LEN); + } + } +} diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs new file mode 100644 index 00000000..ef51e6e3 --- /dev/null +++ b/src/ffi/schema.rs @@ -0,0 +1,54 @@ +use crate::ast::Type; +use crate::schema::Schema; +use std::ffi; +use std::os::raw::c_char; + +#[no_mangle] +pub extern "C" fn schema_new() -> *mut Schema { + Box::into_raw(Box::default()) +} + +/// Deallocate the schema object. +/// +/// # Errors +/// +/// This function never fails. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `schema` must be a valid pointer returned by [`schema_new`]. +#[no_mangle] +pub unsafe extern "C" fn schema_free(schema: *mut Schema) { + drop(Box::from_raw(schema)); +} + +/// Add a new field with the specified type to the schema. +/// +/// # Arguments +/// +/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. +/// - `field`: the C-style string representing the field name. +/// - `typ`: the type of the field. +/// +/// # Panics +/// +/// This function will panic if the C-style string +/// pointed by `field` is not a valid UTF-8 string. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// - `schema` must be a valid pointer returned by [`schema_new`]. +/// - `field` must be a valid pointer to a C-style string, must be properly aligned, +/// and must not have '\0' in the middle. +#[no_mangle] +pub unsafe extern "C" fn schema_add_field(schema: &mut Schema, field: *const i8, typ: Type) { + let field = ffi::CStr::from_ptr(field as *const c_char) + .to_str() + .unwrap(); + + schema.add_field(field, typ) +} From 11f0bfe0f32c65fd73fa70baef3795e1044c23aa Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 30 Oct 2024 17:29:23 +0800 Subject: [PATCH 14/23] feat(expression-compatibility): doc-fix --- src/ffi/expression.rs | 72 +++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index 0cedd4fc..5c322eea 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -96,66 +96,56 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; -/// Validate the ATC expression with the schema. +/// Validates an ATC expression against a schema. /// /// # Arguments /// -/// - `atc`: the C-style string representing the ATC expression. -/// - `schema`: a valid pointer to the [`Schema`] object returned by [`schema_new`]. -/// - `fields_buf`: a buffer to store the used fields. -/// - `fields_len`: a pointer to the length of the fields buffer. -/// - `fields_total`: a pointer for saving the total number of the fields. -/// - `operators`: a pointer for saving the used operators with bitflags. -/// - `errbuf`: a buffer to store the error message. +/// - `atc`: a C-style string representing the ATC expression. +/// - `schema`: a valid pointer to a [`Schema`] object, as returned by [`schema_new`]. +/// - `fields_buf`: a buffer for storing the fields used in the expression. +/// - `fields_len`: a pointer to the length of `fields_buf`. +/// - `fields_total`: a pointer for storing the total number of fields. +/// - `operators`: a pointer for storing the bitflags representing used operators. +/// - `errbuf`: a buffer to store any error messages. /// - `errbuf_len`: a pointer to the length of the error message buffer. /// /// # Returns /// -/// Returns an integer value indicating the validation result: -/// - ATC_ROUTER_EXPRESSION_VALIDATE_OK(0) if validation is passed. -/// - ATC_ROUTER_EXPRESSION_VALIDATE_FAILED(1) if validation is failed. The `errbuf` and `errbuf_len` will be updated with the error message. -/// - ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL(2) if the provided fields buffer is not enough. +/// An integer indicating the validation result: +/// - `ATC_ROUTER_EXPRESSION_VALIDATE_OK` (0): Validation succeeded. +/// - `ATC_ROUTER_EXPRESSION_VALIDATE_FAILED` (1): Validation failed; `errbuf` and `errbuf_len` will be updated with an error message. +/// - `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL` (2): The provided `fields_buf` is too small. /// -/// If `fields_buf` is non-null, and `fields_len` is enough for the required buffer length, -/// it will write the used fields to the buffer, each terminated by '\0' and the total number of fields -/// to the `fields_total`, and `fields_len` will be updated with the total buffer length. +/// If `fields_buf` is non-null and `fields_len` is sufficient, this function writes the used fields to `fields_buf`, +/// each field terminated by `\0`. It updates `fields_len` with the required buffer length and stores the total number of fields in `fields_total`. /// -/// If `fields_buf` is non-null, and `fields_len` is not enough for the required buffer length, -/// it will write the required buffer length to the `fields_len`, and the total number of fields -/// to the `fields_total`, and return `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. -/// -/// If `operators` is non-null, it will write the used operators with bitflags to the provided pointer. -/// The bitflags is defined by `BinaryOperatorFlags` and it must not contain any bits from `BinaryOperatorFlags::UNUSED`. +/// If `fields_buf` is non-null but `fields_len` is insufficient, it writes the required buffer length to `fields_len` +/// and the total number of fields to `fields_total`, then returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. /// +/// If `operators` is non-null, it writes the used operators as bitflags to the provided pointer. +/// Bitflags are defined by `BinaryOperatorFlags` and must exclude bits from `BinaryOperatorFlags::UNUSED`. /// /// # Panics /// -/// This function will panic when: +/// This function will panic if: /// -/// - `atc` doesn't point to a valid C-style string. -/// - `fields_len` and `fields_total` are null when `fields_buf` is non-null. +/// - `atc` does not point to a valid C-style string. +/// - `fields_len` or `fields_total` are null when `fields_buf` is non-null. /// /// # Safety /// -/// Violating any of the following constraints will result in undefined behavior: +/// Violating any of the following constraints results in undefined behavior: /// -/// - `atc` must be a valid pointer to a C-style string, must be properly aligned, -/// and must not have '\0' in the middle. +/// - `atc` must be a valid pointer to a C-style string, properly aligned, and must not contain an internal `\0`. /// - `schema` must be a valid pointer returned by [`schema_new`]. -/// - `fields_buf` must be a valid to write for `fields_len * size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `fields_len` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `fields_total` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `operators` must be a valid to write for `size_of::()` bytes, -/// and it must be properly aligned if non-null. -/// - `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, -/// and it must be properly aligned. -/// - `errbuf_len` must be valid to read and write for `size_of::()` bytes, -/// and it must be properly aligned. -/// - If `fields_buf` is non-null, `fields_len` and `fields_total` must be non-null. -/// for writing required buffer length and total number of fields. +/// - `fields_buf`, if non-null, must be valid for writing `fields_len * size_of::()` bytes and properly aligned. +/// - `fields_len` must be a valid pointer to write `size_of::()` bytes and properly aligned. +/// - `fields_total` must be a valid pointer to write `size_of::()` bytes and properly aligned. +/// - `operators` must be a valid pointer to write `size_of::()` bytes and properly aligned. +/// - `errbuf` must be valid for reading and writing `errbuf_len * size_of::()` bytes and properly aligned. +/// - `errbuf_len` must be a valid pointer for reading and writing `size_of::()` bytes and properly aligned. +/// - If `fields_buf` is non-null, then `fields_len` and `fields_total` must also be non-null to store the buffer length used and total field count. + #[no_mangle] pub unsafe extern "C" fn expression_validate( atc: *const u8, From 9c3cf543e4d05e4abd9231e85a0453369fea8dad Mon Sep 17 00:00:00 2001 From: Shiroko Date: Sat, 2 Nov 2024 08:06:08 +0000 Subject: [PATCH 15/23] feat(expression-compatibility): rename `fields_len` to `fields_buf_len`. --- src/ffi/expression.rs | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index 5c322eea..e15cf8e1 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -103,7 +103,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - `atc`: a C-style string representing the ATC expression. /// - `schema`: a valid pointer to a [`Schema`] object, as returned by [`schema_new`]. /// - `fields_buf`: a buffer for storing the fields used in the expression. -/// - `fields_len`: a pointer to the length of `fields_buf`. +/// - `fields_buf_len`: a pointer to the length of `fields_buf`. /// - `fields_total`: a pointer for storing the total number of fields. /// - `operators`: a pointer for storing the bitflags representing used operators. /// - `errbuf`: a buffer to store any error messages. @@ -116,10 +116,10 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - `ATC_ROUTER_EXPRESSION_VALIDATE_FAILED` (1): Validation failed; `errbuf` and `errbuf_len` will be updated with an error message. /// - `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL` (2): The provided `fields_buf` is too small. /// -/// If `fields_buf` is non-null and `fields_len` is sufficient, this function writes the used fields to `fields_buf`, -/// each field terminated by `\0`. It updates `fields_len` with the required buffer length and stores the total number of fields in `fields_total`. +/// If `fields_buf` is non-null and `fields_buf_len` is sufficient, this function writes the used fields to `fields_buf`, +/// each field terminated by `\0`. It updates `fields_buf_len` with the required buffer length and stores the total number of fields in `fields_total`. /// -/// If `fields_buf` is non-null but `fields_len` is insufficient, it writes the required buffer length to `fields_len` +/// If `fields_buf` is non-null but `fields_buf_len` is insufficient, it writes the required buffer length to `fields_buf_len` /// and the total number of fields to `fields_total`, then returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. /// /// If `operators` is non-null, it writes the used operators as bitflags to the provided pointer. @@ -130,7 +130,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// This function will panic if: /// /// - `atc` does not point to a valid C-style string. -/// - `fields_len` or `fields_total` are null when `fields_buf` is non-null. +/// - `fields_buf_len` or `fields_total` are null when `fields_buf` is non-null. /// /// # Safety /// @@ -138,20 +138,20 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// /// - `atc` must be a valid pointer to a C-style string, properly aligned, and must not contain an internal `\0`. /// - `schema` must be a valid pointer returned by [`schema_new`]. -/// - `fields_buf`, if non-null, must be valid for writing `fields_len * size_of::()` bytes and properly aligned. -/// - `fields_len` must be a valid pointer to write `size_of::()` bytes and properly aligned. +/// - `fields_buf`, if non-null, must be valid for writing `fields_buf_len * size_of::()` bytes and properly aligned. +/// - `fields_buf_len` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `fields_total` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `operators` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `errbuf` must be valid for reading and writing `errbuf_len * size_of::()` bytes and properly aligned. /// - `errbuf_len` must be a valid pointer for reading and writing `size_of::()` bytes and properly aligned. -/// - If `fields_buf` is non-null, then `fields_len` and `fields_total` must also be non-null to store the buffer length used and total field count. +/// - If `fields_buf` is non-null, then `fields_buf_len` and `fields_total` must also be non-null to store the buffer length used and total field count. #[no_mangle] pub unsafe extern "C" fn expression_validate( atc: *const u8, schema: &Schema, fields_buf: *mut u8, - fields_len: *mut usize, + fields_buf_len: *mut usize, fields_total: *mut usize, operators: *mut u64, errbuf: *mut u8, @@ -190,8 +190,8 @@ pub unsafe extern "C" fn expression_validate( // Get used fields if !fields_buf.is_null() { assert!( - !(fields_len.is_null() || fields_total.is_null()), - "fields_len and fields_total must be non-null when fields_buf is non-null" + !(fields_buf_len.is_null() || fields_total.is_null()), + "fields_buf_len and fields_total must be non-null when fields_buf is non-null" ); let expr_fields = predicates @@ -204,8 +204,8 @@ pub unsafe extern "C" fn expression_validate( .sum::(); if !fields_buf.is_null() { - if *fields_len < total_fields_length { - *fields_len = total_fields_length; + if *fields_buf_len < total_fields_length { + *fields_buf_len = total_fields_length; *fields_total = expr_fields.len(); return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; } @@ -221,7 +221,7 @@ pub unsafe extern "C" fn expression_validate( } } - *fields_len = total_fields_length; + *fields_buf_len = total_fields_length; *fields_total = expr_fields.len(); } @@ -252,7 +252,7 @@ mod tests { let mut errbuf_len = ERR_BUF_MAX_LEN; let mut fields_buf = vec![0u8; fields_buf_size]; - let mut fields_len = fields_buf.len(); + let mut fields_buf_len = fields_buf.len(); let mut fields_total = 0; let mut operators = 0u64; @@ -261,7 +261,7 @@ mod tests { atc.as_bytes().as_ptr(), &schema, fields_buf.as_mut_ptr(), - &mut fields_len, + &mut fields_buf_len, &mut fields_total, &mut operators, errbuf.as_mut_ptr(), @@ -279,7 +279,7 @@ mod tests { fields.push(field.to_string_lossy().to_string()); p += len; } - assert_eq!(fields_len, p, "Fields buffer length mismatch"); + assert_eq!(fields_buf_len, p, "Fields buffer length mismatch"); fields.sort(); Ok((fields, operators)) } From ead5337d1bfdcc014d83989dc95f45dbd92e2065 Mon Sep 17 00:00:00 2001 From: Haoxuan Date: Sat, 2 Nov 2024 16:10:21 +0800 Subject: [PATCH 16/23] feat(expression-compatibility): update doc from code-review Co-authored-by: Datong Sun --- src/ffi/expression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index e15cf8e1..bd4b2691 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -104,7 +104,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - `schema`: a valid pointer to a [`Schema`] object, as returned by [`schema_new`]. /// - `fields_buf`: a buffer for storing the fields used in the expression. /// - `fields_buf_len`: a pointer to the length of `fields_buf`. -/// - `fields_total`: a pointer for storing the total number of fields. +/// - `fields_total`: a pointer for storing the total number of unique fields used in the expression. /// - `operators`: a pointer for storing the bitflags representing used operators. /// - `errbuf`: a buffer to store any error messages. /// - `errbuf_len`: a pointer to the length of the error message buffer. From 884e7cce4be2dce265d41d0c908ddad2a16c5c2e Mon Sep 17 00:00:00 2001 From: Shiroko Date: Sat, 2 Nov 2024 08:15:29 +0000 Subject: [PATCH 17/23] feat(expression-compatibility): remove incorrect panic statement for `atc` --- src/ffi/expression.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index bd4b2691..0a310b0a 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -129,7 +129,6 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// /// This function will panic if: /// -/// - `atc` does not point to a valid C-style string. /// - `fields_buf_len` or `fields_total` are null when `fields_buf` is non-null. /// /// # Safety From 5eda1ae22e23407ee39641a8930ee0e8e1e64190 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Sat, 2 Nov 2024 08:56:17 +0000 Subject: [PATCH 18/23] feat(expression-compatibility): simplify API to not allow null parameter. --- src/ffi/expression.rs | 85 +++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 51 deletions(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index 0a310b0a..de6d49e4 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -96,7 +96,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_OK: i64 = 0; pub const ATC_ROUTER_EXPRESSION_VALIDATE_FAILED: i64 = 1; pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; -/// Validates an ATC expression against a schema. +/// Validates an ATC expression against a schema and get its elements. /// /// # Arguments /// @@ -116,20 +116,15 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// - `ATC_ROUTER_EXPRESSION_VALIDATE_FAILED` (1): Validation failed; `errbuf` and `errbuf_len` will be updated with an error message. /// - `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL` (2): The provided `fields_buf` is too small. /// -/// If `fields_buf` is non-null and `fields_buf_len` is sufficient, this function writes the used fields to `fields_buf`, -/// each field terminated by `\0`. It updates `fields_buf_len` with the required buffer length and stores the total number of fields in `fields_total`. +/// If `fields_buf_len` indicates that `fields_buf` is sufficient, this function writes the used fields to `fields_buf`, each field terminated by `\0`. +/// It updates `fields_buf_len` with the required buffer length and stores the total number of fields in `fields_total`. /// -/// If `fields_buf` is non-null but `fields_buf_len` is insufficient, it writes the required buffer length to `fields_buf_len` +/// If `fields_buf_len` indicates that `fields_buf` is insufficient, it writes the required buffer length to `fields_buf_len` /// and the total number of fields to `fields_total`, then returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. /// -/// If `operators` is non-null, it writes the used operators as bitflags to the provided pointer. +/// It writes the used operators as bitflags to `operators`. /// Bitflags are defined by `BinaryOperatorFlags` and must exclude bits from `BinaryOperatorFlags::UNUSED`. /// -/// # Panics -/// -/// This function will panic if: -/// -/// - `fields_buf_len` or `fields_total` are null when `fields_buf` is non-null. /// /// # Safety /// @@ -137,13 +132,12 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// /// - `atc` must be a valid pointer to a C-style string, properly aligned, and must not contain an internal `\0`. /// - `schema` must be a valid pointer returned by [`schema_new`]. -/// - `fields_buf`, if non-null, must be valid for writing `fields_buf_len * size_of::()` bytes and properly aligned. +/// - `fields_buf`, must be valid for writing `fields_buf_len * size_of::()` bytes and properly aligned. /// - `fields_buf_len` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `fields_total` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `operators` must be a valid pointer to write `size_of::()` bytes and properly aligned. /// - `errbuf` must be valid for reading and writing `errbuf_len * size_of::()` bytes and properly aligned. /// - `errbuf_len` must be a valid pointer for reading and writing `size_of::()` bytes and properly aligned. -/// - If `fields_buf` is non-null, then `fields_buf_len` and `fields_total` must also be non-null to store the buffer length used and total field count. #[no_mangle] pub unsafe extern "C" fn expression_validate( @@ -187,51 +181,40 @@ pub unsafe extern "C" fn expression_validate( let predicates = ast.get_predicates(); // Get used fields - if !fields_buf.is_null() { - assert!( - !(fields_buf_len.is_null() || fields_total.is_null()), - "fields_buf_len and fields_total must be non-null when fields_buf is non-null" - ); - - let expr_fields = predicates - .iter() - .map(|p| p.lhs.var_name.as_str()) - .collect::>(); - let total_fields_length = expr_fields - .iter() - .map(|k| k.as_bytes().len() + 1) - .sum::(); - - if !fields_buf.is_null() { - if *fields_buf_len < total_fields_length { - *fields_buf_len = total_fields_length; - *fields_total = expr_fields.len(); - return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; - } - - let mut fields_buf_ptr = fields_buf; - for field in &expr_fields { - let field = ffi::CString::new(*field).unwrap(); - let field_slice = field.as_bytes_with_nul(); - let field_len = field_slice.len(); - let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); - fields_buf.copy_from_slice(field_slice); - fields_buf_ptr = fields_buf_ptr.add(field_len); - } - } - + let expr_fields = predicates + .iter() + .map(|p| p.lhs.var_name.as_str()) + .collect::>(); + let total_fields_length = expr_fields + .iter() + .map(|k| k.as_bytes().len() + 1) + .sum::(); + + if *fields_buf_len < total_fields_length { *fields_buf_len = total_fields_length; *fields_total = expr_fields.len(); + return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; + } + + let mut fields_buf_ptr = fields_buf; + for field in &expr_fields { + let field = ffi::CString::new(*field).unwrap(); + let field_slice = field.as_bytes_with_nul(); + let field_len = field_slice.len(); + let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); + fields_buf.copy_from_slice(field_slice); + fields_buf_ptr = fields_buf_ptr.add(field_len); } + *fields_buf_len = total_fields_length; + *fields_total = expr_fields.len(); + // Get used operators - if !operators.is_null() { - let mut ops = BinaryOperatorFlags::empty(); - for pred in &predicates { - ops |= BinaryOperatorFlags::from(&pred.op); - } - *operators = ops.bits(); + let mut ops = BinaryOperatorFlags::empty(); + for pred in &predicates { + ops |= BinaryOperatorFlags::from(&pred.op); } + *operators = ops.bits(); ATC_ROUTER_EXPRESSION_VALIDATE_OK } From 3788e9275b2394e6c516df8e54feba7463a0eeee Mon Sep 17 00:00:00 2001 From: Shiroko Date: Tue, 5 Nov 2024 16:35:06 +0800 Subject: [PATCH 19/23] feat(expressions-compatibility): use iterator and early stop --- src/ffi/expression.rs | 123 +++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 61 deletions(-) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index de6d49e4..c03cdc7f 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -7,37 +7,43 @@ use std::ffi; use std::os::raw::c_char; use std::slice::from_raw_parts_mut; -impl Expression { - fn get_predicates(&self) -> Vec<&Predicate> { - let mut predicates = Vec::new(); +use std::iter::Iterator; + +struct PredicateIterator<'a> { + stack: Vec<&'a Expression>, +} + +impl<'a> PredicateIterator<'a> { + fn new(expr: &'a Expression) -> Self { + Self { stack: vec![expr] } + } +} - fn visit<'a, 'b>(expr: &'a Expression, predicates: &mut Vec<&'b Predicate>) - where - 'a: 'b, - { +impl<'a> Iterator for PredicateIterator<'a> { + type Item = &'a Predicate; + + fn next(&mut self) -> Option { + while let Some(expr) = self.stack.pop() { match expr { Expression::Logical(l) => match l.as_ref() { - LogicalExpression::And(l, r) => { - visit(l, predicates); - visit(r, predicates); - } - LogicalExpression::Or(l, r) => { - visit(l, predicates); - visit(r, predicates); + LogicalExpression::And(l, r) | LogicalExpression::Or(l, r) => { + self.stack.push(l); + self.stack.push(r); } LogicalExpression::Not(r) => { - visit(r, predicates); + self.stack.push(r); } }, - Expression::Predicate(p) => { - predicates.push(p); - } + Expression::Predicate(p) => return Some(p), } } + None + } +} - visit(self, &mut predicates); - - predicates +impl Expression { + fn iter_predicates(&self) -> PredicateIterator { + PredicateIterator::new(self) } } @@ -119,8 +125,7 @@ pub const ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL: i64 = 2; /// If `fields_buf_len` indicates that `fields_buf` is sufficient, this function writes the used fields to `fields_buf`, each field terminated by `\0`. /// It updates `fields_buf_len` with the required buffer length and stores the total number of fields in `fields_total`. /// -/// If `fields_buf_len` indicates that `fields_buf` is insufficient, it writes the required buffer length to `fields_buf_len` -/// and the total number of fields to `fields_total`, then returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. +/// If `fields_buf_len` indicates that `fields_buf` is insufficient, it returns `ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL`. /// /// It writes the used operators as bitflags to `operators`. /// Bitflags are defined by `BinaryOperatorFlags` and must exclude bits from `BinaryOperatorFlags::UNUSED`. @@ -177,43 +182,38 @@ pub unsafe extern "C" fn expression_validate( return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } - // Direct use GetPredicates trait to avoid unnecessary accesses - let predicates = ast.get_predicates(); - - // Get used fields - let expr_fields = predicates - .iter() - .map(|p| p.lhs.var_name.as_str()) - .collect::>(); - let total_fields_length = expr_fields - .iter() - .map(|k| k.as_bytes().len() + 1) - .sum::(); - - if *fields_buf_len < total_fields_length { - *fields_buf_len = total_fields_length; - *fields_total = expr_fields.len(); - return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; - } - + // Iterate over predicates to get fields and operators + let mut ops = BinaryOperatorFlags::empty(); + let mut existed_fields = HashSet::new(); + let mut total_fields_length = 0; let mut fields_buf_ptr = fields_buf; - for field in &expr_fields { - let field = ffi::CString::new(*field).unwrap(); - let field_slice = field.as_bytes_with_nul(); - let field_len = field_slice.len(); - let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); - fields_buf.copy_from_slice(field_slice); - fields_buf_ptr = fields_buf_ptr.add(field_len); - } - - *fields_buf_len = total_fields_length; - *fields_total = expr_fields.len(); + *fields_total = 0; - // Get used operators - let mut ops = BinaryOperatorFlags::empty(); - for pred in &predicates { + for pred in ast.iter_predicates() { ops |= BinaryOperatorFlags::from(&pred.op); + + let field = pred.lhs.var_name.as_str(); + + if existed_fields.insert(field) { + // Fields is not existed yet. + let field = ffi::CString::new(field).unwrap(); + let field_slice = field.as_bytes_with_nul(); + let field_len = field_slice.len(); + + *fields_total += 1; + total_fields_length += field_len; + + if *fields_buf_len < total_fields_length { + return ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL; + } + + let fields_buf = from_raw_parts_mut(fields_buf_ptr, field_len); + fields_buf.copy_from_slice(field_slice); + fields_buf_ptr = fields_buf_ptr.add(field_len); + } } + + *fields_buf_len = total_fields_length; *operators = ops.bits(); ATC_ROUTER_EXPRESSION_VALIDATE_OK @@ -228,7 +228,7 @@ mod tests { schema: &Schema, atc: &str, fields_buf_size: usize, - ) -> Result<(Vec, u64), (i64, String)> { + ) -> Result<(Vec, usize, u64), (i64, String)> { let atc = ffi::CString::new(atc).unwrap(); let mut errbuf = vec![b'X'; ERR_BUF_MAX_LEN]; let mut errbuf_len = ERR_BUF_MAX_LEN; @@ -263,7 +263,7 @@ mod tests { } assert_eq!(fields_buf_len, p, "Fields buffer length mismatch"); fields.sort(); - Ok((fields, operators)) + Ok((fields, fields_buf_len, operators)) } ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); @@ -284,10 +284,10 @@ mod tests { schema.add_field("net.src.ip", Type::IpAddr); schema.add_field("http.path", Type::String); - let result = expr_validate_on(&schema, atc, 1024); + let result = expr_validate_on(&schema, atc, 47); assert!(result.is_ok(), "Validation failed"); - let (fields, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it + let (fields, fields_buf_len, ops) = result.unwrap(); // Unwrap is safe since we've already asserted it assert_eq!( ops, (BinaryOperatorFlags::EQUALS @@ -308,6 +308,7 @@ mod tests { ], "Fields mismatch" ); + assert_eq!(fields_buf_len, 47, "Fields buffer length mismatch"); } #[test] @@ -369,7 +370,7 @@ mod tests { schema.add_field("net.src.ip", Type::IpAddr); schema.add_field("http.path", Type::String); - let result = expr_validate_on(&schema, atc, 10); + let result = expr_validate_on(&schema, atc, 46); assert!(result.is_err(), "Validation failed"); let (err_code, _) = result.unwrap_err(); // Unwrap is safe since we've already asserted it From 18a919488af6811e0d7a150302e9e6b7527155fa Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 6 Nov 2024 13:54:04 +0800 Subject: [PATCH 20/23] feat(expressions-compatibility): add upwrap safety statement --- src/ffi/expression.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index c03cdc7f..50ff5ccf 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -196,6 +196,7 @@ pub unsafe extern "C" fn expression_validate( if existed_fields.insert(field) { // Fields is not existed yet. + // Unwrap is safe since `field` cannot contain '\0' as `atc` must not contain any internal `\0`. let field = ffi::CString::new(field).unwrap(); let field_slice = field.as_bytes_with_nul(); let field_len = field_slice.len(); From 9fd78c7c54a17328c1c195f7b904a44275b1ee95 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 6 Nov 2024 14:48:37 +0800 Subject: [PATCH 21/23] feat(ffi): improve error buf writing --- src/ffi/context.rs | 8 ++------ src/ffi/expression.rs | 21 ++++++++------------- src/ffi/mod.rs | 31 +++++++++++++++++++++++++++++++ src/ffi/router.rs | 9 +++------ 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/ffi/context.rs b/src/ffi/context.rs index bbeab135..75fa4737 100644 --- a/src/ffi/context.rs +++ b/src/ffi/context.rs @@ -1,8 +1,7 @@ use crate::ast::Value; use crate::context::Context; -use crate::ffi::{CValue, ERR_BUF_MAX_LEN}; +use crate::ffi::{CValue, write_errbuf}; use crate::schema::Schema; -use std::cmp::min; use std::ffi; use std::os::raw::c_char; use std::slice::from_raw_parts_mut; @@ -89,13 +88,10 @@ pub unsafe extern "C" fn context_add_value( let field = ffi::CStr::from_ptr(field as *const c_char) .to_str() .unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); let value: Result = value.try_into(); if let Err(e) = value { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; + write_errbuf(e, errbuf, errbuf_len); return false; } diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index 50ff5ccf..57ac67bb 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -1,8 +1,7 @@ use crate::ast::{BinaryOperator, Expression, LogicalExpression, Predicate}; -use crate::ffi::ERR_BUF_MAX_LEN; +use crate::ffi::write_errbuf; use crate::schema::Schema; use bitflags::bitflags; -use std::cmp::min; use std::ffi; use std::os::raw::c_char; use std::slice::from_raw_parts_mut; @@ -161,24 +160,19 @@ pub unsafe extern "C" fn expression_validate( use crate::semantics::Validate; let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); // Parse the expression - let result = parse(atc).map_err(|e| e.to_string()); + let result = parse(atc); if let Err(e) = result { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; + write_errbuf(e, errbuf, errbuf_len); return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } // Unwrap is safe since we've already checked for error let ast = result.unwrap(); // Validate expression with schema - if let Err(e) = ast.validate(schema).map_err(|e| e.to_string()) { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; + if let Err(e) = ast.validate(schema) { + write_errbuf(e, errbuf, errbuf_len); return ATC_ROUTER_EXPRESSION_VALIDATE_FAILED; } @@ -224,6 +218,7 @@ pub unsafe extern "C" fn expression_validate( mod tests { use super::*; use crate::ast::Type; + use crate::ffi::ERR_BUF_MAX_LEN; fn expr_validate_on( schema: &Schema, @@ -267,8 +262,8 @@ mod tests { Ok((fields, fields_buf_len, operators)) } ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { - let err = String::from_utf8(errbuf[..errbuf_len].to_vec()).unwrap(); - Err((result, err)) + let err = ffi::CStr::from_bytes_with_nul(&errbuf[..errbuf_len]).expect("error message is not null-terminated"); + Err((result, err.to_string_lossy().to_string())) } ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), _ => panic!("Unknown error code"), diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index c7d01b1a..55047efb 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -5,11 +5,14 @@ pub mod schema; use crate::ast::Value; use cidr::IpCidr; +use std::cmp::min; use std::convert::TryFrom; use std::ffi; +use std::fmt::Display; use std::net::IpAddr; use std::os::raw::c_char; use std::slice::from_raw_parts; +use std::slice::from_raw_parts_mut; pub const ERR_BUF_MAX_LEN: usize = 4096; @@ -56,3 +59,31 @@ impl TryFrom<&CValue> for Value { }) } } + +/// Write displayable error message to the error buffer. +/// +/// # Arguments +/// +/// - `err`: the displayable error message. +/// - `errbuf`: a buffer to store the error message. +/// - `errbuf_len`: a pointer to the length of the error message buffer. +/// +/// # Safety +/// +/// Violating any of the following constraints will result in undefined behavior: +/// +/// * `errbuf` must be valid to read and write for `errbuf_len * size_of::()` bytes, +/// and it must be properly aligned. +/// * `errbuf_len` must be vlaid to read and write for `size_of::()` bytes, +/// and it must be properly aligned. +unsafe fn write_errbuf(err: impl Display, errbuf: *mut u8, errbuf_len: *mut usize) { + let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); + // Replace internal '\0' to space. + let err = err.to_string().replace('\0', " "); + // Unwrap is safe since we already remove all internal '\0's. + let err_cstring = std::ffi::CString::new(err.to_string()).unwrap(); + let err_bytes = err_cstring.as_bytes_with_nul(); + let errlen = min(err_bytes.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&err_bytes[..errlen]); + *errbuf_len = errlen; +} diff --git a/src/ffi/router.rs b/src/ffi/router.rs index c7e9f65c..6043c2b3 100644 --- a/src/ffi/router.rs +++ b/src/ffi/router.rs @@ -1,8 +1,7 @@ use crate::context::Context; -use crate::ffi::ERR_BUF_MAX_LEN; +use crate::ffi::write_errbuf; use crate::router::Router; use crate::schema::Schema; -use std::cmp::min; use std::ffi; use std::os::raw::c_char; use std::slice::from_raw_parts_mut; @@ -98,14 +97,11 @@ pub unsafe extern "C" fn router_add_matcher( ) -> bool { let uuid = ffi::CStr::from_ptr(uuid as *const c_char).to_str().unwrap(); let atc = ffi::CStr::from_ptr(atc as *const c_char).to_str().unwrap(); - let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); let uuid = Uuid::try_parse(uuid).expect("invalid UUID format"); if let Err(e) = router.add_matcher(priority, uuid, atc) { - let errlen = min(e.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&e.as_bytes()[..errlen]); - *errbuf_len = errlen; + write_errbuf(e, errbuf, errbuf_len); return false; } @@ -239,6 +235,7 @@ pub unsafe extern "C" fn router_get_fields( #[cfg(test)] mod tests { use super::*; + use crate::ffi::ERR_BUF_MAX_LEN; #[test] fn test_long_error_message() { From 9296af6f5c5f5b4872a20b1fd8611476e943bd88 Mon Sep 17 00:00:00 2001 From: Shiroko Date: Wed, 6 Nov 2024 16:31:07 +0800 Subject: [PATCH 22/23] chore(*): cargo fmt fix --- src/ffi/context.rs | 2 +- src/ffi/expression.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ffi/context.rs b/src/ffi/context.rs index 75fa4737..af6083a6 100644 --- a/src/ffi/context.rs +++ b/src/ffi/context.rs @@ -1,6 +1,6 @@ use crate::ast::Value; use crate::context::Context; -use crate::ffi::{CValue, write_errbuf}; +use crate::ffi::{write_errbuf, CValue}; use crate::schema::Schema; use std::ffi; use std::os::raw::c_char; diff --git a/src/ffi/expression.rs b/src/ffi/expression.rs index 57ac67bb..f9b83aa1 100644 --- a/src/ffi/expression.rs +++ b/src/ffi/expression.rs @@ -262,7 +262,8 @@ mod tests { Ok((fields, fields_buf_len, operators)) } ATC_ROUTER_EXPRESSION_VALIDATE_FAILED => { - let err = ffi::CStr::from_bytes_with_nul(&errbuf[..errbuf_len]).expect("error message is not null-terminated"); + let err = ffi::CStr::from_bytes_with_nul(&errbuf[..errbuf_len]) + .expect("error message is not null-terminated"); Err((result, err.to_string_lossy().to_string())) } ATC_ROUTER_EXPRESSION_VALIDATE_BUF_TOO_SMALL => Err((result, String::new())), From 6bd491e574398ad45e56703efc16fbd3034ac321 Mon Sep 17 00:00:00 2001 From: Haoxuan Date: Tue, 12 Nov 2024 16:48:12 +0800 Subject: [PATCH 23/23] chore(ffi): remove C style string conversion --- src/ffi/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 55047efb..879f96d2 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -79,11 +79,8 @@ impl TryFrom<&CValue> for Value { unsafe fn write_errbuf(err: impl Display, errbuf: *mut u8, errbuf_len: *mut usize) { let errbuf = from_raw_parts_mut(errbuf, ERR_BUF_MAX_LEN); // Replace internal '\0' to space. - let err = err.to_string().replace('\0', " "); - // Unwrap is safe since we already remove all internal '\0's. - let err_cstring = std::ffi::CString::new(err.to_string()).unwrap(); - let err_bytes = err_cstring.as_bytes_with_nul(); - let errlen = min(err_bytes.len(), *errbuf_len); - errbuf[..errlen].copy_from_slice(&err_bytes[..errlen]); + let err = err.to_string(); + let errlen = min(err.len(), *errbuf_len); + errbuf[..errlen].copy_from_slice(&err.as_bytes()[..errlen]); *errbuf_len = errlen; }