diff --git a/src/pal/pal_open_enclave.h b/src/pal/pal_open_enclave.h index 5e37e0438..f8124f57f 100644 --- a/src/pal/pal_open_enclave.h +++ b/src/pal/pal_open_enclave.h @@ -3,8 +3,6 @@ #include "ds/address.h" #include "pal_plain.h" #ifdef OPEN_ENCLAVE -extern "C" const void* __oe_get_heap_base(); -extern "C" const void* __oe_get_heap_end(); extern "C" void* oe_memset_s(void* p, size_t p_size, int c, size_t size); extern "C" [[noreturn]] void oe_abort(); @@ -12,9 +10,19 @@ namespace snmalloc { class PALOpenEnclave { - std::atomic oe_base = nullptr; + static inline std::atomic oe_base; + static inline void* oe_end = nullptr; public: + /** + * This will be called by oe_allocator_init to set up enclave heap bounds. + */ + static void setup_initial_range(void* base, void* end) + { + oe_base = base; + oe_end = end; + } + /** * Bitmap of PalFeatures flags indicating the optional features that this * PAL supports. @@ -32,17 +40,9 @@ namespace snmalloc template void* reserve(size_t size) noexcept { - if (oe_base == 0) - { - void* dummy = NULL; - // If this CAS fails then another thread has initialised this. - oe_base.compare_exchange_strong( - dummy, const_cast(__oe_get_heap_base())); - } - void* old_base = oe_base; void* next_base; - auto end = __oe_get_heap_end(); + auto end = oe_end; do { auto new_base = old_base; diff --git a/src/test/func/fixed_region/fixed_region.cc b/src/test/func/fixed_region/fixed_region.cc index 6c9679780..4f64700b2 100644 --- a/src/test/func/fixed_region/fixed_region.cc +++ b/src/test/func/fixed_region/fixed_region.cc @@ -10,18 +10,6 @@ #endif #define assert please_use_SNMALLOC_ASSERT -void* oe_base; -void* oe_end; -extern "C" const void* __oe_get_heap_base() -{ - return oe_base; -} - -extern "C" const void* __oe_get_heap_end() -{ - return oe_end; -} - extern "C" void* oe_memset_s(void* p, size_t p_size, int c, size_t size) { UNUSED(p_size); @@ -43,8 +31,9 @@ int main() // For 1MiB superslabs, SUPERSLAB_BITS + 4 is not big enough for the example. size_t large_class = 28 - SUPERSLAB_BITS; size_t size = 1ULL << (SUPERSLAB_BITS + large_class); - oe_base = mp.reserve(large_class); - oe_end = (uint8_t*)oe_base + size; + void* oe_base = mp.reserve(large_class); + void* oe_end = (uint8_t*)oe_base + size; + PALOpenEnclave::setup_initial_range(oe_base, oe_end); std::cout << "Allocated region " << oe_base << " - " << oe_end << std::endl; auto a = ThreadAlloc::get(); diff --git a/src/test/func/two_alloc_types/alloc1.cc b/src/test/func/two_alloc_types/alloc1.cc index c86dcde02..f6966005f 100644 --- a/src/test/func/two_alloc_types/alloc1.cc +++ b/src/test/func/two_alloc_types/alloc1.cc @@ -9,3 +9,8 @@ // Redefine the namespace, so we can have two versions. #define snmalloc snmalloc_enclave #include "../../../override/malloc.cc" + +extern "C" void oe_allocator_init(void* base, void* end) +{ + snmalloc_enclave::PALOpenEnclave::setup_initial_range(base, end); +} diff --git a/src/test/func/two_alloc_types/main.cc b/src/test/func/two_alloc_types/main.cc index 341f32dc7..1ef9d1cd1 100644 --- a/src/test/func/two_alloc_types/main.cc +++ b/src/test/func/two_alloc_types/main.cc @@ -5,18 +5,6 @@ #include #include -void* oe_base; -void* oe_end; -extern "C" const void* __oe_get_heap_base() -{ - return oe_base; -} - -extern "C" const void* __oe_get_heap_end() -{ - return oe_end; -} - extern "C" void* oe_memset_s(void* p, size_t p_size, int c, size_t size) { UNUSED(p_size); @@ -28,6 +16,7 @@ extern "C" void oe_abort() abort(); } +extern "C" void oe_allocator_init(void* base, void* end); extern "C" void* host_malloc(size_t); extern "C" void host_free(void*); @@ -51,8 +40,9 @@ int main() // For 1MiB superslabs, SUPERSLAB_BITS + 2 is not big enough for the example. size_t large_class = 26 - SUPERSLAB_BITS; size_t size = 1ULL << (SUPERSLAB_BITS + large_class); - oe_base = mp.reserve(large_class); - oe_end = (uint8_t*)oe_base + size; + void* oe_base = mp.reserve(large_class); + void* oe_end = (uint8_t*)oe_base + size; + oe_allocator_init(oe_base, oe_end); std::cout << "Allocated region " << oe_base << " - " << oe_end << std::endl; // Call these functions to trigger asserts if the cast-to-self doesn't work.