From 793361003ab4ddd2c3f42dc620f2a7994d5e9c25 Mon Sep 17 00:00:00 2001 From: Emanuele Giaquinta Date: Wed, 11 Dec 2024 21:50:23 +0200 Subject: [PATCH] Refactor parsing of packb and unpackb arguments Signed-off-by: Emanuele Giaquinta --- src/lib.rs | 123 +++++++++++++++++++++------------------------- src/opt.rs | 4 +- tests/test_api.py | 70 +++++++++++++------------- 3 files changed, 92 insertions(+), 105 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2ee29946..09c4c070 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,6 @@ mod typeref; mod unicode; use pyo3::ffi::*; -use std::borrow::Cow; use std::os::raw::c_char; use std::os::raw::c_int; use std::os::raw::c_long; @@ -143,8 +142,7 @@ pub unsafe extern "C" fn ormsgpack_exec(mptr: *mut PyObject) -> c_int { #[cold] #[inline(never)] -fn raise_unpackb_exception(err: deserialize::DeserializeError) -> *mut PyObject { - let msg = err.message; +fn raise_unpackb_exception(msg: &str) -> *mut PyObject { unsafe { let err_msg = PyUnicode_FromStringAndSize(msg.as_ptr() as *const c_char, msg.len() as isize); @@ -158,7 +156,7 @@ fn raise_unpackb_exception(err: deserialize::DeserializeError) -> *mut PyObject #[cold] #[inline(never)] -fn raise_packb_exception(msg: Cow<str>) -> *mut PyObject { +fn raise_packb_exception(msg: &str) -> *mut PyObject { unsafe { let err_msg = PyUnicode_FromStringAndSize(msg.as_ptr() as *const c_char, msg.len() as isize); @@ -168,6 +166,21 @@ fn raise_packb_exception(msg: Cow<str>) -> *mut PyObject { std::ptr::null_mut() } +unsafe fn parse_option_arg(opts: *mut PyObject, mask: i32) -> Result<i32, ()> { + if Py_TYPE(opts) == typeref::INT_TYPE { + let val = PyLong_AsLong(opts) as i32; + if val & !mask == 0 { + Ok(val) + } else { + Err(()) + } + } else if opts == typeref::NONE { + Ok(0) + } else { + Err(()) + } +} + #[no_mangle] pub unsafe extern "C" fn unpackb( _self: *mut PyObject, @@ -181,50 +194,37 @@ pub unsafe extern "C" fn unpackb( let num_args = PyVectorcall_NARGS(nargs as usize); if unlikely!(num_args != 1) 
{ let msg = if num_args > 1 { - Cow::Borrowed("unpackb() accepts only 1 positional argument") + "unpackb() accepts only 1 positional argument" } else { - Cow::Borrowed("unpackb() missing 1 required positional argument: 'obj'") + "unpackb() missing 1 required positional argument: 'obj'" }; - return raise_unpackb_exception(deserialize::DeserializeError::new(msg)); + return raise_unpackb_exception(msg); } if !kwnames.is_null() { let tuple_size = PyTuple_GET_SIZE(kwnames); - if tuple_size > 0 { - for i in 0..=tuple_size - 1 { - let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t); - if arg == typeref::EXT_HOOK { - ext_hook = Some(NonNull::new_unchecked(*args.offset(num_args + i))); - } else if arg == typeref::OPTION { - optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i))); - } else { - return raise_unpackb_exception(deserialize::DeserializeError::new( - Cow::Borrowed("unpackb() got an unexpected keyword argument"), - )); - } + for i in 0..tuple_size { + let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t); + if arg == typeref::EXT_HOOK { + ext_hook = Some(NonNull::new_unchecked(*args.offset(num_args + i))); + } else if arg == typeref::OPTION { + optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i))); + } else { + return raise_unpackb_exception("unpackb() got an unexpected keyword argument"); } } } let mut optsbits: i32 = 0; if let Some(opts) = optsptr { - let ob_type = (*opts.as_ptr()).ob_type; - if ob_type == typeref::INT_TYPE { - optsbits = PyLong_AsLong(optsptr.unwrap().as_ptr()) as i32; - if !(0..=opt::MAX_UNPACKB_OPT).contains(&optsbits) { - return raise_unpackb_exception(deserialize::DeserializeError::new(Cow::Borrowed( - "Invalid opts", - ))); - } - } else if ob_type != typeref::NONE_TYPE { - return raise_unpackb_exception(deserialize::DeserializeError::new(Cow::Borrowed( - "Invalid opts", - ))); + match parse_option_arg(opts.as_ptr(), opt::UNPACKB_OPT_MASK) { + Ok(val) => optsbits = val, + Err(()) => return 
raise_unpackb_exception("Invalid opts"), } } match crate::deserialize::deserialize(*args, ext_hook, optsbits as opt::Opt) { Ok(val) => val.as_ptr(), - Err(err) => raise_unpackb_exception(err), + Err(err) => raise_unpackb_exception(&err.message), } } @@ -240,59 +240,48 @@ pub unsafe extern "C" fn packb( let num_args = PyVectorcall_NARGS(nargs as usize); if unlikely!(num_args == 0) { - return raise_packb_exception(Cow::Borrowed( - "packb() missing 1 required positional argument: 'obj'", - )); + return raise_packb_exception("packb() missing 1 required positional argument: 'obj'"); } - if num_args & 2 == 2 { + if num_args >= 2 { default = Some(NonNull::new_unchecked(*args.offset(1))); } - if num_args & 3 == 3 { + if num_args >= 3 { optsptr = Some(NonNull::new_unchecked(*args.offset(2))); } if !kwnames.is_null() { let tuple_size = PyTuple_GET_SIZE(kwnames); - if tuple_size > 0 { - for i in 0..=tuple_size - 1 { - let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t); - if arg == typeref::DEFAULT { - if unlikely!(num_args & 2 == 2) { - return raise_packb_exception(Cow::Borrowed( - "packb() got multiple values for argument: 'default'", - )); - } - default = Some(NonNull::new_unchecked(*args.offset(num_args + i))); - } else if arg == typeref::OPTION { - if unlikely!(num_args & 3 == 3) { - return raise_packb_exception(Cow::Borrowed( - "packb() got multiple values for argument: 'option'", - )); - } - optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i))); - } else { - return raise_packb_exception(Cow::Borrowed( - "packb() got an unexpected keyword argument", - )); + for i in 0..tuple_size { + let arg = PyTuple_GET_ITEM(kwnames, i as Py_ssize_t); + if arg == typeref::DEFAULT { + if unlikely!(default.is_some()) { + return raise_packb_exception( + "packb() got multiple values for argument: 'default'", + ); } + default = Some(NonNull::new_unchecked(*args.offset(num_args + i))); + } else if arg == typeref::OPTION { + if unlikely!(optsptr.is_some()) { + return 
raise_packb_exception( + "packb() got multiple values for argument: 'option'", + ); + } + optsptr = Some(NonNull::new_unchecked(*args.offset(num_args + i))); + } else { + return raise_packb_exception("packb() got an unexpected keyword argument"); } } } let mut optsbits: i32 = 0; if let Some(opts) = optsptr { - let ob_type = (*opts.as_ptr()).ob_type; - if ob_type == typeref::INT_TYPE { - optsbits = PyLong_AsLong(optsptr.unwrap().as_ptr()) as i32; - if !(0..=opt::MAX_PACKB_OPT).contains(&optsbits) { - return raise_packb_exception(Cow::Borrowed("Invalid opts")); - } - } else if ob_type != typeref::NONE_TYPE { - return raise_packb_exception(Cow::Borrowed("Invalid opts")); + match parse_option_arg(opts.as_ptr(), opt::PACKB_OPT_MASK) { + Ok(val) => optsbits = val, + Err(()) => return raise_packb_exception("Invalid opts"), } } match crate::serialize::serialize(*args, default, optsbits as opt::Opt) { Ok(val) => val.as_ptr(), - Err(err) => raise_packb_exception(Cow::Borrowed(&err)), + Err(err) => raise_packb_exception(&err), } } diff --git a/src/opt.rs b/src/opt.rs index 4d97a412..57f88014 100644 --- a/src/opt.rs +++ b/src/opt.rs @@ -21,7 +21,7 @@ pub const NOT_PASSTHROUGH: Opt = !(PASSTHROUGH_BIG_INT | PASSTHROUGH_SUBCLASS | PASSTHROUGH_TUPLE); -pub const MAX_PACKB_OPT: i32 = (NAIVE_UTC +pub const PACKB_OPT_MASK: i32 = (NAIVE_UTC | NON_STR_KEYS | OMIT_MICROSECONDS | PASSTHROUGH_BIG_INT @@ -34,4 +34,4 @@ pub const MAX_PACKB_OPT: i32 = (NAIVE_UTC | SORT_KEYS | UTC_Z) as i32; -pub const MAX_UNPACKB_OPT: i32 = NON_STR_KEYS as i32; +pub const UNPACKB_OPT_MASK: i32 = NON_STR_KEYS as i32; diff --git a/tests/test_api.py b/tests/test_api.py index bc46e639..8b8da401 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -82,44 +82,42 @@ def test_valueerror() -> None: ormsgpack.unpackb(b"\x91") -def test_option_not_int() -> None: - """ - packb/unpackb() option not int or None - """ - with pytest.raises(ormsgpack.MsgpackEncodeError): - ormsgpack.packb(True, option=True) - with 
pytest.raises(ormsgpack.MsgpackDecodeError): - ormsgpack.unpackb(b"\x00", option=True) - - -def test_option_invalid_int() -> None: - """ - packb/unpackb() option invalid 64-bit number - """ - with pytest.raises(ormsgpack.MsgpackEncodeError): - ormsgpack.packb(True, option=9223372036854775809) - with pytest.raises(ormsgpack.MsgpackDecodeError): - ormsgpack.unpackb(b"\x00", option=9223372036854775809) - - -def test_option_range_low() -> None: - """ - packb/unpackb() option out of range low - """ - with pytest.raises(ormsgpack.MsgpackEncodeError): - ormsgpack.packb(True, option=-1) - with pytest.raises(ormsgpack.MsgpackDecodeError): - ormsgpack.unpackb(b"\x00", option=-1) - - -def test_option_range_high() -> None: - """ - packb/unpackb() option out of range high - """ +@pytest.mark.parametrize( + "option", + ( + 1 << 12, + True, + -1, + 9223372036854775809, + ), +) +def test_packb_invalid_option(option: int) -> None: with pytest.raises(ormsgpack.MsgpackEncodeError): - ormsgpack.packb(True, option=1 << 14) + ormsgpack.packb(True, option=option) + + +@pytest.mark.parametrize( + "option", + ( + ormsgpack.OPT_NAIVE_UTC, + ormsgpack.OPT_OMIT_MICROSECONDS, + ormsgpack.OPT_PASSTHROUGH_BIG_INT, + ormsgpack.OPT_PASSTHROUGH_DATACLASS, + ormsgpack.OPT_PASSTHROUGH_DATETIME, + ormsgpack.OPT_PASSTHROUGH_SUBCLASS, + ormsgpack.OPT_PASSTHROUGH_TUPLE, + ormsgpack.OPT_SERIALIZE_NUMPY, + ormsgpack.OPT_SERIALIZE_PYDANTIC, + ormsgpack.OPT_SORT_KEYS, + ormsgpack.OPT_UTC_Z, + True, + -1, + 9223372036854775809, + ), +) +def test_unpackb_invalid_option(option: int) -> None: with pytest.raises(ormsgpack.MsgpackDecodeError): - ormsgpack.unpackb(b"\x00", option=1 << 14) + ormsgpack.unpackb(b"\x00", option=option) def test_opts_multiple() -> None: