Skip to content

Commit

Permalink
Add threadpool performance test (#8447)
Browse files Browse the repository at this point in the history
* Add ability to get and set thread pool size from JIT-land

* Add test of various parallel scenarios

* Fix atomics test

* Fix comment
  • Loading branch information
abadams authored Oct 30, 2024
1 parent 402a056 commit 978b39c
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 16 deletions.
28 changes: 28 additions & 0 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,24 @@ void JITModule::reuse_device_allocations(bool b) const {
}
}

int JITModule::get_num_threads() const {
std::map<std::string, Symbol>::const_iterator f =
exports().find("halide_get_num_threads");
if (f != exports().end()) {
return (reinterpret_bits<int (*)()>(f->second.address))();
}
return 1;
}

int JITModule::set_num_threads(int n) const {
std::map<std::string, Symbol>::const_iterator f =
exports().find("halide_set_num_threads");
if (f != exports().end()) {
return (reinterpret_bits<int (*)(int)>(f->second.address))(n);
}
return 1;
}

bool JITModule::compiled() const {
return jit_module->JIT != nullptr;
}
Expand Down Expand Up @@ -1075,6 +1093,16 @@ void JITSharedRuntime::reuse_device_allocations(bool b) {
shared_runtimes(MainShared).reuse_device_allocations(b);
}

int JITSharedRuntime::get_num_threads() {
    // Hold the shared-runtimes mutex for the duration of the query.
    std::lock_guard<std::mutex> guard(shared_runtimes_mutex);
    // Ask the main shared runtime module for its thread count.
    const int thread_count = shared_runtimes(MainShared).get_num_threads();
    return thread_count;
}

int JITSharedRuntime::set_num_threads(int n) {
    // Hold the shared-runtimes mutex for the duration of the update.
    std::lock_guard<std::mutex> guard(shared_runtimes_mutex);
    // Forward the request to the main shared runtime module and hand back
    // the value it reports.
    const int previous = shared_runtimes(MainShared).set_num_threads(n);
    return previous;
}

JITCache::JITCache(Target jit_target,
std::vector<Argument> arguments,
std::map<std::string, JITExtern> jit_externs,
Expand Down
18 changes: 18 additions & 0 deletions src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ struct JITModule {
/** See JITSharedRuntime::reuse_device_allocations */
void reuse_device_allocations(bool) const;

/** See JITSharedRuntime::get_num_threads. Calls the
 * halide_get_num_threads symbol exported by this module; returns 1 if
 * the module exports no such symbol. */
int get_num_threads() const;

/** See JITSharedRuntime::set_num_threads. Calls the
 * halide_set_num_threads symbol exported by this module with the given
 * count and returns its result; returns 1 (and changes nothing) if the
 * module exports no such symbol. */
int set_num_threads(int) const;

/** Return true if compile_module has been called on this module. */
bool compiled() const;
};
Expand Down Expand Up @@ -279,6 +285,18 @@ class JITSharedRuntime {
static void reuse_device_allocations(bool);

static void release_all();

/** Get the number of threads in the Halide thread pool. Includes the
 * calling thread. Meaningless if a custom_do_par_for has been set.
 * The query is made on the main shared runtime while holding the
 * shared runtimes mutex. */
static int get_num_threads();

/** Set the number of threads to use in the Halide thread pool, inclusive of
 * the calling thread. Pass zero to use a reasonable default (typically the
 * number of CPUs online). Calling this is meaningless if custom_do_par_for
 * has been set. Halide may launch more threads than this if necessary to
 * avoid deadlock when using the async scheduling directive. Returns the old
 * number. The request is forwarded to the main shared runtime while
 * holding the shared runtimes mutex. */
static int set_num_threads(int);
};

void *get_symbol_address(const char *s);
Expand Down
5 changes: 4 additions & 1 deletion src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ extern struct halide_thread *halide_spawn_thread(void (*f)(void *), void *closur
/** Join a thread. */
extern void halide_join_thread(struct halide_thread *);

/** Set the number of threads used by Halide's thread pool. Returns
/** Get or set the number of threads used by Halide's thread pool. Set returns
* the old number.
*
* n < 0 : error condition
Expand All @@ -402,7 +402,10 @@ extern void halide_join_thread(struct halide_thread *);
* of halide_do_par_for(); custom implementations may completely ignore values
* passed to halide_set_num_threads().)
*/
// @{
extern int halide_get_num_threads();
extern int halide_set_num_threads(int n);
// @}

/** Halide calls these functions to allocate and free memory. To
* replace in AOT code, use the halide_set_custom_malloc and
Expand Down
1 change: 1 addition & 0 deletions src/runtime/runtime_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
(void *)&halide_get_cpu_features,
(void *)&halide_get_gpu_device,
(void *)&halide_get_library_symbol,
(void *)&halide_get_num_threads,
(void *)&halide_get_symbol,
(void *)&halide_get_trace_file,
(void *)&halide_hexagon_detach_device_handle,
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/thread_pool_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,13 @@ WEAK int halide_set_num_threads(int n) {
return old;
}

WEAK int halide_get_num_threads() {
    // Read the desired thread count while holding the work queue mutex.
    halide_mutex_lock(&work_queue.mutex);
    int desired = work_queue.desired_threads_working;
    halide_mutex_unlock(&work_queue.mutex);
    return desired;
}

WEAK void halide_shutdown_thread_pool() {
if (work_queue.initialized) {
// Wake everyone up and tell them the party's over and it's time
Expand Down
13 changes: 4 additions & 9 deletions test/correctness/atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,15 +1196,10 @@ int main(int argc, char **argv) {
}

Target target = get_jit_target_from_environment();
// Most of the schedules used in this test are terrible for large
// thread count machines, due to massive amounts of
// contention. We'll just set the thread count to 4. Unfortunately
// there's no JIT api for this yet.
#ifdef _WIN32
_putenv_s("HL_NUM_THREADS", "4");
#else
setenv("HL_NUM_THREADS", "4", 1);
#endif
// Most of the schedules used in this test are terrible for large
// thread count machines, due to massive amounts of
// contention. We'll just set the thread count to 4.
Halide::Internal::JITSharedRuntime::set_num_threads(4);
test_all<uint8_t>(Backend::CPU);
test_all<uint8_t>(Backend::CPUVectorize);
test_all<int8_t>(Backend::CPU);
Expand Down
1 change: 1 addition & 0 deletions test/performance/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tests(GROUPS performance multithreaded
matrix_multiplication.cpp
memory_profiler.cpp
parallel_performance.cpp
parallel_scenarios.cpp
profiler.cpp
rfactor.cpp
sort.cpp
Expand Down
7 changes: 1 addition & 6 deletions test/performance/inner_loop_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ int main(int argc, char **argv) {
for (int t = 2; t <= 64; t *= 2) {
std::ostringstream ss;
ss << "HL_NUM_THREADS=" << t;
std::string str = ss.str();
char buf[32] = {0};
memcpy(buf, str.c_str(), str.size());
putenv(buf);
p.invalidate_cache();
Halide::Internal::JITSharedRuntime::release_all();
Halide::Internal::JITSharedRuntime::set_num_threads(t);

p.compile_jit();
// Start the thread pool without giving any hints as to the
Expand Down
98 changes: 98 additions & 0 deletions test/performance/parallel_scenarios.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#include "Halide.h"
#include "halide_thread_pool.h"

using namespace Halide;

// Performance test that measures the per-iteration cost of a parallel loop
// across a grid of scenarios: compute-bound vs memory-bound inner loops,
// a single caller vs several concurrent callers, and a range of inner and
// outer loop extents. For every scenario it prints eight quantiles of the
// measured time distribution (in nanoseconds per inner iteration) as one
// whitespace-separated row.
int main(int argc, char **argv) {
    Param<int> inner_iterations, outer_iterations, memory_limit;
    ImageParam input(Float(32), 1);

    Func f, g;
    Var x;

    RDom r(0, inner_iterations);
    // Make an inner loop with a floating point sqrt, some integer
    // multiply-adds, and a random int generation, and a random memory access.
    f(x) = sum(sqrt(input(random_int(r) % memory_limit)));

    // Reading the first and last element makes f be computed over
    // outer_iterations values of x.
    g() = f(0) + f(outer_iterations - 1);

    f.compute_root().parallel(x);

    auto out = Runtime::Buffer<float>::make_scalar();
    const int max_memory = 100 * 1024 * 1024;
    Runtime::Buffer<float> in(max_memory);
    in.fill(17.0f);

    auto callable = g.compile_to_callable({inner_iterations, outer_iterations, memory_limit, input});

    // We want the full distribution of runtimes, not the denoised min, so we
    // won't use Tools::benchmark here.

    // Thread pool size before this test changes it; used as the baseline
    // for both the contended and uncontended configurations.
    int native_threads = Halide::Internal::JITSharedRuntime::get_num_threads();

    // Run one scenario and print one row of results.
    //   m: memory bound (random accesses span the whole input) or not
    //   c: contended (several threads call the pipeline at once) or not
    //   i: inner (serial reduction) iterations
    //   o: outer (parallel) iterations
    auto bench = [&](bool m, bool c, int i, int o) {
        const int num_samples = 128;
        const int memory_limit = m ? max_memory : 128;

        // Time a single call, normalized to nanoseconds per inner iteration.
        auto bench_one = [&]() {
            auto t1 = std::chrono::high_resolution_clock::now();
            callable(i, o, memory_limit, in, out);
            auto t2 = std::chrono::high_resolution_clock::now();
            return 1e9 * std::chrono::duration<float>(t2 - t1).count() / (i * o);
        };

        std::vector<float> times(num_samples);
        if (c) {
            Halide::Tools::ThreadPool<void> thread_pool;
            const int num_tasks = 8;
            const int samples_per_task = num_samples / num_tasks;
            // Scale the Halide thread pool with the number of concurrent callers.
            Halide::Internal::JITSharedRuntime::set_num_threads(num_tasks * native_threads);
            std::vector<std::future<void>> futures(num_tasks);
            for (size_t t = 0; t < futures.size(); t++) {
                futures[t] = thread_pool.async(
                    [&](size_t t) {
                        // Untimed warm-up call.
                        bench_one();
                        // Each task writes to its own disjoint slice of times.
                        for (int s = 0; s < samples_per_task; s++) {
                            size_t idx = t * samples_per_task + s;
                            times[idx] = bench_one();
                        }
                    },
                    t);
            }
            for (auto &f : futures) {
                f.get();
            }
        } else {
            Halide::Internal::JITSharedRuntime::set_num_threads(native_threads);
            // Untimed warm-up call.
            bench_one();
            for (int s = 0; s < num_samples; s++) {
                times[s] = bench_one();
            }
        }
        std::sort(times.begin(), times.end());
        printf("%d %d %d %d ", m, c, i, o);
        // Print n evenly spaced quantiles, each taken from the middle of its
        // bucket of sorted samples.
        const int n = 8;
        int off = (num_samples / n) / 2;
        for (int q = 0; q < n; q++) {
            printf("%g ", times[off + (num_samples * q) / n]);
        }
        printf("\n");
    };

    // The output is designed to be copy-pasted into a spreadsheet, not read by a human.
    // One header column per printed quantile (n = 8 above).
    printf("memory_bound contended inner outer t0 t1 t2 t3 t4 t5 t6 t7\n");
    for (bool contended : {false, true}) {
        for (bool memory_bound : {false, true}) {
            for (int i : {1 << 0, 1 << 6, 1 << 12, 1 << 18}) {
                for (int o : {1, 2, 4, 8, 16, 32, 64, 128, 256}) {
                    bench(memory_bound, contended, i, o);
                }
            }
        }
    }

    printf("Success!\n");

    return 0;
}

0 comments on commit 978b39c

Please sign in to comment.