diff --git a/runtime/memory-map/cbindgen.toml b/runtime/memory-map/cbindgen.toml index 08d115a0e..906a5cdbe 100644 --- a/runtime/memory-map/cbindgen.toml +++ b/runtime/memory-map/cbindgen.toml @@ -13,6 +13,7 @@ trailer = "/* clang-format on */" [export.rename] "MemoryMap" = "memory_map" +"PROT_INDETERMINATE" = "MEMORY_MAP_PROT_INDETERMINATE" "Range" = "range" # the below does not help because our types are opaque items, so they aren't diff --git a/runtime/memory-map/src/lib.rs b/runtime/memory-map/src/lib.rs index 94eecf6ed..5b0c4bed3 100644 --- a/runtime/memory-map/src/lib.rs +++ b/runtime/memory-map/src/lib.rs @@ -64,6 +64,8 @@ impl Range { pub struct State { pub owner_pkey: u8, pub pkey_mprotected: bool, + pub mprotected: bool, + pub prot: u32, } /// A contiguous region of the memory map whose state we track @@ -190,12 +192,18 @@ impl MemoryMap { Some(r.state) } - pub fn split_region(&mut self, subrange: Range, owner_pkey: u8) -> Option { + pub fn split_region( + &mut self, + subrange: Range, + owner_pkey: u8, + prot: u32, + ) -> Option { let state = self.split_out_region(subrange)?; // add the new split-off range let new_state = State { owner_pkey, + prot, ..state }; self.add_region(subrange, new_state); @@ -239,28 +247,74 @@ pub extern "C" fn memory_map_all_overlapping_regions_pkey_mprotected( pkey_mprotected: bool, ) -> bool { map.all_overlapping_regions(needle, |region| { + let same = pkey_mprotected == region.state.pkey_mprotected; printerrln!( - "does {:?} have pkey_mprotected=={}? =={}", + "does {:?} have pkey_mprotected=={}? {}", region.range, pkey_mprotected, - region.state.pkey_mprotected + if same { "yes" } else { "no" } ); - pkey_mprotected == region.state.pkey_mprotected + same }) } +#[no_mangle] +pub extern "C" fn memory_map_all_overlapping_regions_mprotected( + map: &MemoryMap, + needle: Range, + mprotected: bool, +) -> bool { + map.all_overlapping_regions(needle, |region| { + let same = mprotected == region.state.mprotected; + printerrln!( + "does {:?} have mprotected=={}? {}", + region.range, + mprotected, + if same { "yes" } else { "no" } + ); + same + }) +} + +#[no_mangle] +pub extern "C" fn memory_map_region_get_prot(map: &MemoryMap, needle: Range) -> u32 { + let mut prot = None; + let same = map.all_overlapping_regions(needle, |region| match prot { + None => { + prot = Some(region.state.prot); + true + } + Some(prot) => prot == region.state.prot, + }); + if same { + prot.unwrap() + } else { + PROT_INDETERMINATE + } +} + +/** memory_map_region_get_prot found no or multiple protections in the given range */ +pub const PROT_INDETERMINATE: u32 = 0xFFFFFFFFu32; + #[no_mangle] pub extern "C" fn memory_map_unmap_region(map: &mut MemoryMap, needle: Range) -> bool { map.split_out_region(needle).is_some() } #[no_mangle] -pub extern "C" fn memory_map_add_region(map: &mut MemoryMap, range: Range, owner_pkey: u8) -> bool { +pub extern "C" fn memory_map_add_region( + map: &mut MemoryMap, + range: Range, + owner_pkey: u8, + prot: u32, +) -> bool { map.add_region( range, State { owner_pkey, + mprotected: false, pkey_mprotected: false, + prot, }, ) } @@ -270,8 +324,9 @@ pub extern "C" fn memory_map_split_region( map: &mut MemoryMap, range: Range, owner_pkey: u8, + prot: u32, ) -> bool { - map.split_region(range, owner_pkey).is_some() + map.split_region(range, owner_pkey, prot).is_some() } #[no_mangle] @@ -291,3 +346,21 @@ pub extern "C" fn memory_map_pkey_mprotect_region( false } } + +#[no_mangle] +pub extern "C" fn memory_map_mprotect_region(map: &mut MemoryMap, range: Range, prot: u32) -> bool { + if let Some(mut state) = map.split_out_region(range) { + if state.mprotected == false { + state.mprotected = true; + state.prot = prot; + map.add_region(range, state) + } else { + printerrln!("already mprotected, prot {} => {}", state.prot, prot); + state.mprotected = true; + state.prot = prot; + map.add_region(range, state) + } + } else { + false + } +} diff --git a/runtime/memory_map.h b/runtime/memory_map.h index 1b33cab16..ef60211f5 100644 --- a/runtime/memory_map.h +++ b/runtime/memory_map.h @@ -9,6 +9,11 @@ #include /* clang-format off */ +/** + * memory_map_region_get_prot found no or multiple protections in the given range + */ +#define MEMORY_MAP_PROT_INDETERMINATE 4294967295u + struct memory_map; struct range { @@ -28,12 +33,26 @@ bool memory_map_all_overlapping_regions_pkey_mprotected(const struct memory_map struct range needle, bool pkey_mprotected); +bool memory_map_all_overlapping_regions_mprotected(const struct memory_map *map, + struct range needle, + bool mprotected); + +uint32_t memory_map_region_get_prot(const struct memory_map *map, struct range needle); + bool memory_map_unmap_region(struct memory_map *map, struct range needle); -bool memory_map_add_region(struct memory_map *map, struct range range, uint8_t owner_pkey); +bool memory_map_add_region(struct memory_map *map, + struct range range, + uint8_t owner_pkey, + uint32_t prot); -bool memory_map_split_region(struct memory_map *map, struct range range, uint8_t owner_pkey); +bool memory_map_split_region(struct memory_map *map, + struct range range, + uint8_t owner_pkey, + uint32_t prot); bool memory_map_pkey_mprotect_region(struct memory_map *map, struct range range, uint8_t pkey); +bool memory_map_mprotect_region(struct memory_map *map, struct range range, uint32_t prot); + /* clang-format on */ diff --git a/runtime/track_memory_map.c b/runtime/track_memory_map.c index f2567bcff..c67a8127c 100644 --- a/runtime/track_memory_map.c +++ b/runtime/track_memory_map.c @@ -29,13 +29,22 @@ bool is_op_permitted(struct memory_map *map, int event, map, info->mremap.old_range, info->mremap.pkey)) return true; break; - case EVENT_MPROTECT: - if (memory_map_all_overlapping_regions_have_pkey(map, info->mprotect.range, - info->mprotect.pkey)) + case EVENT_MPROTECT: { + /* allow mprotecting memory that has not been mprotected */ + bool impacts_only_unprotected_memory = + memory_map_all_overlapping_regions_mprotected(map, info->mprotect.range, + false); + if (impacts_only_unprotected_memory) + return true; + + /* allow mprotecting memory that is already writable */ + uint32_t prot = memory_map_region_get_prot(map, info->mprotect.range); + if (prot != MEMORY_MAP_PROT_INDETERMINATE && (prot & PROT_WRITE)) return true; break; + } case EVENT_PKEY_MPROTECT: { - /* allow mprotecting memory that we own to our pkey */ + /* allow mprotecting memory that has not been pkey_mprotected to our pkey */ bool impacts_only_unprotected_memory = memory_map_all_overlapping_regions_pkey_mprotected( map, info->pkey_mprotect.range, false); @@ -62,27 +71,48 @@ bool update_memory_map(struct memory_map *map, int event, union event_info *info) { switch (event) { case EVENT_MMAP: - return memory_map_add_region(map, info->mmap.range, info->mmap.pkey); + if (info->mmap.flags & MAP_FIXED) { + // mapping a fixed address is allowed to overlap/split existing regions + if (!memory_map_split_region(map, info->mmap.range, info->mmap.pkey, + info->mmap.prot)) { + fprintf(stderr, "no split, adding region\n"); + return memory_map_add_region(map, info->mmap.range, info->mmap.pkey, + info->mmap.prot); + } else { + return true; + } + } else { + return memory_map_add_region(map, info->mmap.range, info->mmap.pkey, + info->mmap.prot); + } break; case EVENT_MUNMAP: return memory_map_unmap_region(map, info->munmap.range); break; - case EVENT_MREMAP: + case EVENT_MREMAP: { + uint32_t prot = memory_map_region_get_prot(map, info->mremap.old_range); + if (prot == MEMORY_MAP_PROT_INDETERMINATE) { + fprintf(stderr, "could not find prot for region to mremap\n"); + exit(1); + } + /* we don't need to handle MREMAP_MAYMOVE specially because we don't assume the old and new ranges have the same start */ /* similarly, MREMAP_FIXED simply lets the request dictate the new range's start addr, about which we make no assumptions */ if (info->mremap.flags & MREMAP_DONTUNMAP) { return memory_map_add_region(map, info->mremap.new_range, - info->mremap.pkey); + info->mremap.pkey, prot); } else { memory_map_unmap_region(map, info->mremap.old_range); return memory_map_add_region(map, info->mremap.new_range, - info->mremap.pkey); + info->mremap.pkey, prot); } break; + } case EVENT_MPROTECT: - return true; + return memory_map_mprotect_region(map, info->mprotect.range, + info->mprotect.prot); break; case EVENT_PKEY_MPROTECT: { return memory_map_pkey_mprotect_region(map, info->mprotect.range,