Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lux] Multithreading for JIT code #31

Open
mratsim opened this issue Jul 14, 2019 · 0 comments
Open

[Lux] Multithreading for JIT code #31

mratsim opened this issue Jul 14, 2019 · 0 comments

Comments

@mratsim
Copy link
Owner

mratsim commented Jul 14, 2019

This issue tracks multithreading solutions for JIT code.

Description

At the moment, Lux only targets Nim and so can make use of OpenMP for threading.

In the future, Lux will probably add a JIT solution via LLVM IR. This will reduce code size, allow generating code specialized for specific sizes, and allow targeting new architectures that would otherwise require complex C extensions.

Example: CUDA introduces __global__ and magic expressions like blockIdx.x * blockDim.x + threadIdx.x that require some gymnastics for Nim to generate code and not throw "undefined".

Unfortunately, when doing JIT on CPU we lose OpenMP support, as OpenMP is implemented in Clang and replaced by library calls in LLVM IR. So we need an alternative solution.

Solutions to explore

  1. Reuse Nim threadpool library.

  2. Implement a threading library from scratch

  3. Wrap a C/C++ library (note that C++ will cause issues with cpuinfo with some compilers due to it using C99)

  4. Wait for OpenMP IR to be merged in LLVM see:

OpenMP code transformation

from https://stackoverflow.com/questions/52285368/how-does-llvm-translate-openmp-multi-threaded-code-with-runtime-library-calls

This OMP code

// foo is only declared here; it is called once per loop iteration below.
extern float foo( void );
int main () {
    int i;
    float r = 0.0;
    // Run the 10 iterations in parallel with dynamic scheduling, and combine
    // each thread's partial sum into r using '+' when the loop finishes.
    #pragma omp parallel for schedule(dynamic) reduction(+:r)
    for ( i = 0; i < 10; i ++ ) {
        r += foo();
    }
}

is transformed into

extern float foo( void );
int main () {
    static int zero = 0;
    auto int gtid;
    auto float r = 0.0;
    __kmpc_begin( & loc3, 0 );
    // The gtid is not actually required in this example so could be omitted;
    // We show its initialization here because it is often required for calls into
    // the runtime and should be locally cached like this.
    gtid = __kmpc_global thread num( & loc3 );
    __kmpc_fork call( & loc7, 1, main_7_parallel_3, & r );
    __kmpc_end( & loc0 );
    return 0;
}

// Holds one thread's private partial result for the reduction on r.
struct main_10_reduction_t_5 { float r_10_rpr; };

// Lock object handed to __kmpc_reduce_nowait / __kmpc_end_reduce_nowait.
static kmp_critical_name lck = { 0 };
static ident_t loc10; // loc10.flags should contain KMP_IDENT_ATOMIC_REDUCE bit set
                      // if compiler has generated an atomic reduction.
// Outlined body of the parallel loop; main passes this function to
// __kmpc_fork_call.
//   gtid    - pointer to the calling thread's global thread id; dereferenced
//             for every runtime call below.
//   btid    - not used in this body.
//   r_7_shp - pointer to the shared reduction variable r owned by main.
void main_7_parallel_3( int *gtid, int *btid, float *r_7_shp ) {
    auto int i_7_pr;
    auto int lower, upper, liter, incr;
    auto struct main_10_reduction_t_5 reduce;
    // Thread-private accumulator starts at the identity for '+'.
    reduce.r_10_rpr = 0.F;
    liter = 0;
    // 0 and 9 match the source loop bounds (i = 0 .. 9).
    // NOTE(review): 35 is presumably the runtime's schedule-kind constant for
    // schedule(dynamic) - confirm against the runtime headers.
    __kmpc_dispatch_init_4( & loc7,*gtid, 35, 0, 9, 1, 1 );
    // Keep asking the runtime for work; each non-zero return has filled in
    // lower/upper with the next chunk of iterations for this thread.
    while ( __kmpc_dispatch_next_4( & loc7, *gtid, & liter, & lower, & upper, & incr
      ) ) {
        for( i_7_pr = lower; upper >= i_7_pr; i_7_pr ++ )
          reduce.r_10_rpr += foo();
    }
    // Fold the private partial sum into the shared r; the runtime's return
    // value selects how this thread must do it.
    switch( __kmpc_reduce_nowait( & loc10, *gtid, 1, 4, & reduce, main_10_reduce_5, &lck ) ) {
        case 1:
           // Plain (non-atomic) add, then tell the runtime the reduction is done.
           *r_7_shp += reduce.r_10_rpr;
           __kmpc_end_reduce_nowait( & loc10, *gtid, & lck );
           break;
        case 2:
           // Add through the runtime's atomic float add instead.
           __kmpc_atomic_float4_add( & loc10, *gtid, r_7_shp, reduce.r_10_rpr );
           break;
        default:;
    }
}

in LLVM IR:

[...]
; Function Attrs: nounwind uwtable
; main after lowering: allocates i and r, zeroes r, and hands the outlined
; loop body (@.omp_outlined.) to the OpenMP runtime via @__kmpc_fork_call.
define dso_local i32 @main() local_unnamed_addr #0 {
entry:
  %i = alloca i32, align 4
  %r = alloca float, align 4
  %0 = bitcast i32* %i to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #4
  %1 = bitcast float* %r to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %1) #4
  store float 0.000000e+00, float* %r, align 4, !tbaa !2
  ; i32 2 matches the two trailing pointer arguments (%i and %r) that are
  ; forwarded to @.omp_outlined.
  call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @0, i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, float*)* @.omp_outlined. to void (i32*, i32*, ...)*), i32* nonnull %i, float* nonnull %r) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %1) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #4
  ret i32 0
}
[...]

; Function Attrs: norecurse nounwind uwtable
; Outlined parallel-loop body that @main passes to @__kmpc_fork_call.
; %r points at main's shared r; %r1 is this thread's private partial sum.
define internal void @.omp_outlined.(i32* noalias nocapture readonly %.global_tid., i32* noalias nocapture readnone %.bound_tid., i32* nocapture readnone dereferenceable(4) %i, float* nocapture dereferenceable(4) %r) #2 {
entry:
  ; Set up lb = 0, ub = 9, stride = 1, is_last = 0 and the private sum
  ; r1 = 0.0, then initialise dispatching and request the first chunk.
  %.omp.lb = alloca i32, align 4
  %.omp.ub = alloca i32, align 4
  %.omp.stride = alloca i32, align 4
  %.omp.is_last = alloca i32, align 4
  %r1 = alloca float, align 4
  %.omp.reduction.red_list = alloca [1 x i8*], align 8
  %0 = bitcast i32* %.omp.lb to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %0) #4
  store i32 0, i32* %.omp.lb, align 4, !tbaa !6
  %1 = bitcast i32* %.omp.ub to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %1) #4
  store i32 9, i32* %.omp.ub, align 4, !tbaa !6
  %2 = bitcast i32* %.omp.stride to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %2) #4
  store i32 1, i32* %.omp.stride, align 4, !tbaa !6
  %3 = bitcast i32* %.omp.is_last to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %3) #4
  store i32 0, i32* %.omp.is_last, align 4, !tbaa !6
  %4 = bitcast float* %r1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %4) #4
  store float 0.000000e+00, float* %r1, align 4, !tbaa !2
  %5 = load i32, i32* %.global_tid., align 4, !tbaa !6
  tail call void @__kmpc_dispatch_init_4(%struct.ident_t* nonnull @0, i32 %5, i32 35, i32 0, i32 9, i32 1, i32 1) #4
  %6 = call i32 @__kmpc_dispatch_next_4(%struct.ident_t* nonnull @0, i32 %5, i32* nonnull %.omp.is_last, i32* nonnull %.omp.lb, i32* nonnull %.omp.ub, i32* nonnull %.omp.stride) #4
  %tobool14 = icmp eq i32 %6, 0
  br i1 %tobool14, label %omp.dispatch.end, label %omp.dispatch.body

omp.dispatch.cond.loopexit:                       ; preds = %omp.inner.for.body, %omp.dispatch.body
  ; Current chunk is finished: ask the runtime for the next one; a zero
  ; result means there is no more work for this thread.
  %7 = call i32 @__kmpc_dispatch_next_4(%struct.ident_t* nonnull @0, i32 %5, i32* nonnull %.omp.is_last, i32* nonnull %.omp.lb, i32* nonnull %.omp.ub, i32* nonnull %.omp.stride) #4
  %tobool = icmp eq i32 %7, 0
  br i1 %tobool, label %omp.dispatch.end, label %omp.dispatch.body

omp.dispatch.body:                                ; preds = %entry, %omp.dispatch.cond.loopexit
  ; Read the chunk bounds written by the runtime; skip the chunk if lb > ub.
  %8 = load i32, i32* %.omp.lb, align 4, !tbaa !6
  %9 = load i32, i32* %.omp.ub, align 4, !tbaa !6, !llvm.mem.parallel_loop_access !8
  %cmp12 = icmp sgt i32 %8, %9
  br i1 %cmp12, label %omp.dispatch.cond.loopexit, label %omp.inner.for.body

omp.inner.for.body:                               ; preds = %omp.dispatch.body, %omp.inner.for.body
  ; One loop iteration: r1 += foo(), then advance the induction variable and
  ; continue while the pre-increment value is still below ub.
  %.omp.iv.013 = phi i32 [ %add4, %omp.inner.for.body ], [ %8, %omp.dispatch.body ]
  %call = call float @foo() #4, !llvm.mem.parallel_loop_access !8
  %10 = load float, float* %r1, align 4, !tbaa !2, !llvm.mem.parallel_loop_access !8
  %add3 = fadd float %call, %10
  store float %add3, float* %r1, align 4, !tbaa !2, !llvm.mem.parallel_loop_access !8
  %add4 = add nsw i32 %.omp.iv.013, 1
  %11 = load i32, i32* %.omp.ub, align 4, !tbaa !6, !llvm.mem.parallel_loop_access !8
  %cmp = icmp slt i32 %.omp.iv.013, %11
  br i1 %cmp, label %omp.inner.for.body, label %omp.dispatch.cond.loopexit, !llvm.loop !8

omp.dispatch.end:                                 ; preds = %omp.dispatch.cond.loopexit, %entry
  ; Publish the address of r1 in the one-element reduction list and let the
  ; runtime decide how this thread combines it into r.
  %12 = bitcast [1 x i8*]* %.omp.reduction.red_list to float**
  store float* %r1, float** %12, align 8
  %13 = bitcast [1 x i8*]* %.omp.reduction.red_list to i8*
  %14 = call i32 @__kmpc_reduce_nowait(%struct.ident_t* nonnull @1, i32 %5, i32 1, i64 8, i8* nonnull %13, void (i8*, i8*)* nonnull @.omp.reduction.reduction_func, [8 x i32]* nonnull @.gomp_critical_user_.reduction.var) #4
  switch i32 %14, label %.omp.reduction.default [
    i32 1, label %.omp.reduction.case1
    i32 2, label %.omp.reduction.case2
  ]

.omp.reduction.case1:                             ; preds = %omp.dispatch.end
  ; Result 1: plain load/add/store of r += r1, then notify the runtime.
  %15 = load float, float* %r, align 4, !tbaa !2
  %16 = load float, float* %r1, align 4, !tbaa !2
  %add5 = fadd float %15, %16
  store float %add5, float* %r, align 4, !tbaa !2
  call void @__kmpc_end_reduce_nowait(%struct.ident_t* nonnull @1, i32 %5, [8 x i32]* nonnull @.gomp_critical_user_.reduction.var) #4
  br label %.omp.reduction.default

.omp.reduction.case2:                             ; preds = %omp.dispatch.end
  ; Result 2: atomic r += r1, done on r's 32-bit pattern via the cmpxchg
  ; retry loop in atomic_cont.
  %17 = bitcast float* %r to i32*
  %atomic-load = load atomic i32, i32* %17 monotonic, align 4, !tbaa !2
  %18 = load float, float* %r1, align 4, !tbaa !2
  br label %atomic_cont

atomic_cont:                                      ; preds = %atomic_cont, %.omp.reduction.case2
  ; Compute old + r1 and try to swap it in; on failure retry with the value
  ; that cmpxchg observed.
  %19 = phi i32 [ %atomic-load, %.omp.reduction.case2 ], [ %23, %atomic_cont ]
  %20 = bitcast i32 %19 to float
  %add7 = fadd float %18, %20
  %21 = bitcast float %add7 to i32
  %22 = cmpxchg i32* %17, i32 %19, i32 %21 monotonic monotonic
  %23 = extractvalue { i32, i1 } %22, 0
  %24 = extractvalue { i32, i1 } %22, 1
  br i1 %24, label %.omp.reduction.default, label %atomic_cont

.omp.reduction.default:                           ; preds = %atomic_cont, %.omp.reduction.case1, %omp.dispatch.end
  ; Common exit: end the lifetimes of the locals and return.
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %4) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %3) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %2) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %1) #4
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %0) #4
  ret void
}
[...]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant