Skip to content

Commit

Permalink
Pass length to free register list callback
Browse files Browse the repository at this point in the history
Allows language bindings like rust to free register lists sanely
  • Loading branch information
emesare committed Dec 14, 2024
1 parent 487fa4b commit 121c165
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 93 deletions.
2 changes: 1 addition & 1 deletion architecture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ size_t Architecture::GetSemanticFlagGroupLowLevelILCallback(void* ctxt, uint32_t
}


void Architecture::FreeRegisterListCallback(void*, uint32_t* regs)
void Architecture::FreeRegisterListCallback(void*, uint32_t* regs, size_t)
{
delete[] regs;
}
Expand Down
4 changes: 2 additions & 2 deletions binaryninjaapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -7940,7 +7940,7 @@ namespace BinaryNinja {
static size_t GetFlagConditionLowLevelILCallback(
void* ctxt, BNLowLevelILFlagCondition cond, uint32_t semClass, BNLowLevelILFunction* il);
static size_t GetSemanticFlagGroupLowLevelILCallback(void* ctxt, uint32_t semGroup, BNLowLevelILFunction* il);
static void FreeRegisterListCallback(void* ctxt, uint32_t* regs);
static void FreeRegisterListCallback(void* ctxt, uint32_t* regs, size_t len);
static void GetRegisterInfoCallback(void* ctxt, uint32_t reg, BNRegisterInfo* result);
static uint32_t GetStackPointerRegisterCallback(void* ctxt);
static uint32_t GetLinkRegisterCallback(void* ctxt);
Expand Down Expand Up @@ -15026,7 +15026,7 @@ namespace BinaryNinja {
static uint32_t* GetCalleeSavedRegistersCallback(void* ctxt, size_t* count);
static uint32_t* GetIntegerArgumentRegistersCallback(void* ctxt, size_t* count);
static uint32_t* GetFloatArgumentRegistersCallback(void* ctxt, size_t* count);
static void FreeRegisterListCallback(void* ctxt, uint32_t* regs);
static void FreeRegisterListCallback(void* ctxt, uint32_t* regs, size_t len);

static bool AreArgumentRegistersSharedIndexCallback(void* ctxt);
static bool AreArgumentRegistersUsedForVarArgsCallback(void* ctxt);
Expand Down
4 changes: 2 additions & 2 deletions binaryninjacore.h
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ extern "C"
size_t (*getFlagConditionLowLevelIL)(
void* ctxt, BNLowLevelILFlagCondition cond, uint32_t semClass, BNLowLevelILFunction* il);
size_t (*getSemanticFlagGroupLowLevelIL)(void* ctxt, uint32_t semGroup, BNLowLevelILFunction* il);
void (*freeRegisterList)(void* ctxt, uint32_t* regs);
void (*freeRegisterList)(void* ctxt, uint32_t* regs, size_t count);
void (*getRegisterInfo)(void* ctxt, uint32_t reg, BNRegisterInfo* result);
uint32_t (*getStackPointerRegister)(void* ctxt);
uint32_t (*getLinkRegister)(void* ctxt);
Expand Down Expand Up @@ -2541,7 +2541,7 @@ extern "C"
uint32_t* (*getCalleeSavedRegisters)(void* ctxt, size_t* count);
uint32_t* (*getIntegerArgumentRegisters)(void* ctxt, size_t* count);
uint32_t* (*getFloatArgumentRegisters)(void* ctxt, size_t* count);
void (*freeRegisterList)(void* ctxt, uint32_t* regs);
void (*freeRegisterList)(void* ctxt, uint32_t* regs, size_t len);

bool (*areArgumentRegistersSharedIndex)(void* ctxt);
bool (*isStackReservedForArgumentRegisters)(void* ctxt);
Expand Down
2 changes: 1 addition & 1 deletion callingconvention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ uint32_t* CallingConvention::GetFloatArgumentRegistersCallback(void* ctxt, size_
}


void CallingConvention::FreeRegisterListCallback(void*, uint32_t* regs)
void CallingConvention::FreeRegisterListCallback(void*, uint32_t* regs, size_t)
{
delete[] regs;
}
Expand Down
2 changes: 1 addition & 1 deletion python/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def _get_semantic_flag_group_low_level_il(self, ctxt, sem_group, il):
log_error(traceback.format_exc())
return 0

def _free_register_list(self, ctxt, regs):
def _free_register_list(self, ctxt, regs, count):
try:
buf = ctypes.cast(regs, ctypes.c_void_p)
if buf.value not in self._pending_reg_lists:
Expand Down
2 changes: 1 addition & 1 deletion python/callingconvention.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _get_float_arg_regs(self, ctxt, count):
count[0] = 0
return None

def _free_register_list(self, ctxt, regs):
def _free_register_list(self, ctxt, regs, count):
try:
buf = ctypes.cast(regs, ctypes.c_void_p)
if buf.value not in self._pending_reg_lists:
Expand Down
134 changes: 84 additions & 50 deletions rust/src/architecture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1941,104 +1941,117 @@ where
None => BnString::new("invalid_flag_group").into_raw(),
}
}

fn alloc_register_list<I: Iterator<Item = u32> + ExactSizeIterator>(
items: I,
count: &mut usize,
) -> *mut u32 {
let len = items.len();
*count = len;

if len == 0 {
ptr::null_mut()
} else {
let mut res: Box<[_]> = [len as u32].into_iter().chain(items).collect();

let raw = res.as_mut_ptr();
mem::forget(res);

unsafe { raw.offset(1) }
}
}


extern "C" fn cb_registers_full_width<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let regs = custom_arch.registers_full_width();
let mut regs = custom_arch.registers_full_width();

alloc_register_list(regs.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = regs.len() };
let regs_ptr = regs.as_mut_ptr();
mem::forget(regs);
regs_ptr as *mut _
}

extern "C" fn cb_registers_all<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let regs = custom_arch.registers_all();
let mut regs = custom_arch.registers_all();

alloc_register_list(regs.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = regs.len() };
let regs_ptr = regs.as_mut_ptr();
mem::forget(regs);
regs_ptr as *mut _
}

extern "C" fn cb_registers_global<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let regs = custom_arch.registers_global();
let mut regs = custom_arch.registers_global();

alloc_register_list(regs.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = regs.len() };
let regs_ptr = regs.as_mut_ptr();
mem::forget(regs);
regs_ptr as *mut _
}

extern "C" fn cb_registers_system<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let regs = custom_arch.registers_system();
let mut regs = custom_arch.registers_system();

alloc_register_list(regs.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = regs.len() };
let regs_ptr = regs.as_mut_ptr();
mem::forget(regs);
regs_ptr as *mut _
}

extern "C" fn cb_flags<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let flags = custom_arch.flags();
let mut flags = custom_arch.flags();

alloc_register_list(flags.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = flags.len() };
let regs_ptr = flags.as_mut_ptr();
mem::forget(flags);
regs_ptr as *mut _
}

extern "C" fn cb_flag_write_types<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let flag_writes = custom_arch.flag_write_types();
let mut flag_writes = custom_arch.flag_write_types();

alloc_register_list(flag_writes.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = flag_writes.len() };
let regs_ptr = flag_writes.as_mut_ptr();
mem::forget(flag_writes);
regs_ptr as *mut _
}

extern "C" fn cb_semantic_flag_classes<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let flag_classes = custom_arch.flag_classes();
let mut flag_classes = custom_arch.flag_classes();

alloc_register_list(flag_classes.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = flag_classes.len() };
let regs_ptr = flag_classes.as_mut_ptr();
mem::forget(flag_classes);
regs_ptr as *mut _
}

extern "C" fn cb_semantic_flag_groups<A>(ctxt: *mut c_void, count: *mut usize) -> *mut u32
where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let flag_groups = custom_arch.flag_groups();
let mut flag_groups = custom_arch.flag_groups();

alloc_register_list(flag_groups.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = flag_groups.len() };
let regs_ptr = flag_groups.as_mut_ptr();
mem::forget(flag_groups);
regs_ptr as *mut _
}

extern "C" fn cb_flag_role<A>(ctxt: *mut c_void, flag: u32, class: u32) -> BNFlagRole
Expand Down Expand Up @@ -2068,9 +2081,13 @@ where
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let class = custom_arch.flag_class_from_id(class);
let flags = custom_arch.flags_required_for_flag_condition(cond, class);
let mut flags = custom_arch.flags_required_for_flag_condition(cond, class);

alloc_register_list(flags.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: `count` is an out parameter
unsafe { *count = flags.len() };
let regs_ptr = flags.as_mut_ptr();
mem::forget(flags);
regs_ptr as *mut _
}

extern "C" fn cb_flags_required_for_semantic_flag_group<A>(
Expand All @@ -2084,8 +2101,13 @@ where
let custom_arch = unsafe { &*(ctxt as *mut A) };

if let Some(group) = custom_arch.flag_group_from_id(group) {
let flags = group.flags_required();
alloc_register_list(flags.iter().map(|r| r.id()), unsafe { &mut *count })
let mut flags = group.flags_required();

// SAFETY: `count` is an out parameter
unsafe { *count = flags.len() };
let regs_ptr = flags.as_mut_ptr();
mem::forget(flags);
regs_ptr as *mut _
} else {
unsafe {
*count = 0;
Expand Down Expand Up @@ -2153,8 +2175,13 @@ where
let custom_arch = unsafe { &*(ctxt as *mut A) };

if let Some(write_type) = custom_arch.flag_write_from_id(write_type) {
let written = write_type.flags_written();
alloc_register_list(written.iter().map(|f| f.id()), unsafe { &mut *count })
let mut written = write_type.flags_written();

// SAFETY: `count` is an out parameter
unsafe { *count = written.len() };
let regs_ptr = written.as_mut_ptr();
mem::forget(written);
regs_ptr as *mut _
} else {
unsafe {
*count = 0;
Expand Down Expand Up @@ -2285,15 +2312,13 @@ where
lifter.unimplemented().expr_idx
}

extern "C" fn cb_free_register_list(_ctxt: *mut c_void, regs: *mut u32) {
extern "C" fn cb_free_register_list(_ctxt: *mut c_void, regs: *mut u32, count: usize) {
if regs.is_null() {
return;
}

unsafe {
let actual_start = regs.offset(-1);
let len = *actual_start + 1;
let regs_ptr = ptr::slice_from_raw_parts_mut(actual_start, len.try_into().unwrap());
let regs_ptr = ptr::slice_from_raw_parts_mut(regs, count);
let _regs = Box::from_raw(regs_ptr);
}
}
Expand Down Expand Up @@ -2362,9 +2387,13 @@ where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let regs = custom_arch.register_stacks();
let mut regs = custom_arch.register_stacks();

alloc_register_list(regs.iter().map(|r| r.id()), unsafe { &mut *count })
// SAFETY: Passed in to be written
unsafe { *count = regs.len() };
let regs_ptr = regs.as_mut_ptr();
mem::forget(regs);
regs_ptr as *mut _
}

extern "C" fn cb_reg_stack_info<A>(
Expand Down Expand Up @@ -2420,8 +2449,13 @@ where
A: 'static + Architecture<Handle = CustomArchitectureHandle<A>> + Send + Sync,
{
let custom_arch = unsafe { &*(ctxt as *mut A) };
let intrinsics = custom_arch.intrinsics();
alloc_register_list(intrinsics.iter().map(|i| i.id()), unsafe { &mut *count })
let mut intrinsics = custom_arch.intrinsics();

// SAFETY: Passed in to be written
unsafe { *count = intrinsics.len() };
let regs_ptr = intrinsics.as_mut_ptr();
mem::forget(intrinsics);
regs_ptr as *mut _
}

extern "C" fn cb_intrinsic_inputs<A>(
Expand Down
Loading

0 comments on commit 121c165

Please sign in to comment.