diff --git a/mpi/gen_mpi.rb b/mpi/gen_mpi.rb index defe8ac6..7cc6e7f6 100644 --- a/mpi/gen_mpi.rb +++ b/mpi/gen_mpi.rb @@ -1,11 +1,5 @@ require_relative 'mpi_model' -puts <<~EOF - #include - #include - #include "mpi_tracepoints.h" -EOF - def common_block(c, provider) params = c.parameters.collect(&:name) tp_params = c.parameters.collect do |p| @@ -31,10 +25,10 @@ def common_block(c, provider) if c.has_return_type? puts < + #include + #include "mpi_tracepoints.h" + #include + #include +EOF + +define_and_find_mpi_symbols + +puts File::read(File.join(SRC_DIR,"tracer_mpi_helpers.include.c")) + $mpi_commands.each do |c| - next if c.name.start_with?("PMPI") normal_wrapper(c, :lttng_ust_mpi) end diff --git a/mpi/mpi_model.rb b/mpi/mpi_model.rb index e1c3d102..6dc8bb5f 100644 --- a/mpi/mpi_model.rb +++ b/mpi/mpi_model.rb @@ -25,7 +25,7 @@ mpi_funcs_e = $mpi_api["functions"] -INIT_FUNCTIONS=/None/ +INIT_FUNCTIONS=/MPI_Init|MPI_Init_thread/ $mpi_meta_parameters = YAML::load_file(File.join(SRC_DIR, "mpi_meta_parameters.yaml")) $mpi_meta_parameters.fetch("meta_parameters",[]).each { |func, list| @@ -39,7 +39,7 @@ } def upper_snake_case(str) - str.gsub(/([A-Z][A-Z0-9]*)/, '_\1').upcase + str.gsub(/([a-z][a-z0-9]*)/, '_\1').upcase end MPI_POINTER_NAMES = $mpi_commands.collect { |c| diff --git a/mpi/tracer_mpi_helpers.include.c b/mpi/tracer_mpi_helpers.include.c new file mode 100644 index 00000000..bdbdea6e --- /dev/null +++ b/mpi/tracer_mpi_helpers.include.c @@ -0,0 +1,48 @@ +static pthread_once_t _init = PTHREAD_ONCE_INIT; +static __thread volatile int in_init = 0; +static volatile unsigned int _initialized = 0; + +static void _load_tracer(void) { + char *s = NULL; + void *handle = NULL; + int verbose = 0; + + s = getenv("LTTNG_UST_MPI_LIBMPI_LOADER"); + if (s) + handle = dlopen(s, RTLD_LAZY | RTLD_LOCAL | RTLD_DEEPBIND); + else + handle = dlopen("libmpi.so", RTLD_LAZY | RTLD_LOCAL | RTLD_DEEPBIND); + if (handle) { + void* ptr = dlsym(handle, "MPI_Init"); + if (ptr == (void*)&MPI_Init) { //opening oneself + dlclose(handle); + handle = NULL; + } + } + + if( !handle ) { + fprintf(stderr, "THAPI: Failure: could not load MPI library!\n"); + exit(1); + } + + s = getenv("LTTNG_UST_MPI_VERBOSE"); + if (s) + verbose = 1; + + find_mpi_symbols(handle, verbose); +} + +static inline void _init_tracer(void) { + if( __builtin_expect (_initialized, 1) ) + return; + /* Avoid reentrancy */ + if (!in_init) { + in_init=1; + __sync_synchronize(); + pthread_once(&_init, _load_tracer); + __sync_synchronize(); + in_init=0; + } + _initialized = 1; +} +