Skip to content

Offload device1 #142696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 123 additions & 29 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,34 @@ pub(crate) fn handle_gpu_code<'ll>(
// The offload memory transfer type for each kernel
let mut o_types = vec![];
let mut kernels = vec![];
let mut region_ids = vec![];
let offload_entry_ty = add_tgt_offload_entry(&cx);
for num in 0..9 {
let kernel = cx.get_function(&format!("kernel_{num}"));
if let Some(kernel) = kernel {
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
o_types.push(o);
region_ids.push(k);
kernels.push(kernel);
}
}

gen_call_handling(&cx, &kernels, &o_types);
gen_call_handling(&cx, &kernels, &o_types, &region_ids);
}

// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
let ti32 = cx.type_i32();
let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
let tgt_fn_ty = cx.type_func(&args, ti32);
let name = "__tgt_target_kernel";
let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
(tgt_decl, tgt_fn_ty)
}

// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
Expand Down Expand Up @@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
offload_entry_ty
}

fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, Vec<&'ll llvm::Type>) {
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
let tptr = cx.type_ptr();
let ti64 = cx.type_i64();
Expand Down Expand Up @@ -118,9 +136,7 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
// For now we don't handle kernels, so for now we just add a global dummy
// to make sure that the __tgt_offload_entry is defined and handled correctly.
cx.declare_global("my_struct_global2", kernel_arguments_ty);
(kernel_arguments_ty, kernel_elements)
}

fn gen_tgt_data_mappers<'ll>(
Expand Down Expand Up @@ -187,7 +203,7 @@ fn gen_define_handling<'ll>(
kernel: &'ll llvm::Value,
offload_entry_ty: &'ll llvm::Type,
num: i64,
) -> &'ll llvm::Value {
) -> (&'ll llvm::Value, &'ll llvm::Value) {
let types = cx.func_params_types(cx.get_type_of_global(kernel));
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
// reference) types.
Expand All @@ -205,10 +221,11 @@ fn gen_define_handling<'ll>(
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
// will be 2. For now, everything is 3, until we have our frontend set up.
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add extra input ptr once, idk, figure out later)
let o_types =
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![1+2+32; num_ptr_types]);
// Next: For each function, generate these three entries. A weak constant,
// the llvm.rodata entry name, and the omp_offloading_entries value
// the llvm.rodata entry name, and the llvm_offload_entries value

let name = format!(".kernel_{num}.region_id");
let initializer = cx.get_const_i8(0);
Expand Down Expand Up @@ -242,13 +259,13 @@ fn gen_define_handling<'ll>(
llvm::set_global_constant(llglobal, true);
llvm::set_linkage(llglobal, WeakAnyLinkage);
llvm::set_initializer(llglobal, initializer);
llvm::set_alignment(llglobal, Align::ONE);
let c_section_name = CString::new(".omp_offloading_entries").unwrap();
llvm::set_alignment(llglobal, Align::EIGHT);
let c_section_name = CString::new("llvm_offload_entries").unwrap();
llvm::set_section(llglobal, &c_section_name);
o_types
(o_types, region_id)
}

fn declare_offload_fn<'ll>(
pub(crate) fn declare_offload_fn<'ll>(
cx: &'ll SimpleCx<'_>,
name: &str,
ty: &'ll llvm::Type,
Expand Down Expand Up @@ -287,15 +304,17 @@ fn gen_call_handling<'ll>(
cx: &'ll SimpleCx<'_>,
_kernels: &[&'ll llvm::Value],
o_types: &[&'ll llvm::Value],
region_ids: &[&'ll llvm::Value],
) {
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
let tptr = cx.type_ptr();
let ti32 = cx.type_i32();
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

gen_tgt_kernel_global(&cx);
let (tgt_kernel_decl, tgt_kernel_types) = gen_tgt_kernel_global(&cx);
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

let main_fn = cx.get_function("main");
Expand Down Expand Up @@ -329,29 +348,33 @@ fn gen_call_handling<'ll>(
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
let ty2 = cx.type_array(cx.type_i64(), num_args);
let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

//%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

// Step 1)
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

// Now we allocate once per function param, a copy to be passed to one of our maps.
let mut vals = vec![];
let mut geps = vec![];
let i32_0 = cx.get_const_i32(0);
for (index, in_ty) in types.iter().enumerate() {
// get function arg, store it into the alloca, and read it.
let p = llvm::get_param(called, index as u32);
let name = llvm::get_value_name(p);
let name = str::from_utf8(&name).unwrap();
let arg_name = format!("{name}.addr");
let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);

builder.store(p, alloca, Align::EIGHT);
let val = builder.load(in_ty, alloca, Align::EIGHT);
let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
vals.push(val);
//let p = llvm::get_param(called, index as u32);
//let name = llvm::get_value_name(p);
//let name = str::from_utf8(&name).unwrap();
//let arg_name = format!("{name}.addr");
//let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);

let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
vals.push(v);
//vals.push(val);
geps.push(gep);
}

// Step 1)
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
Expand Down Expand Up @@ -421,16 +444,87 @@ fn gen_call_handling<'ll>(

// Step 3)
// Here we will add code for the actual kernel launches in a follow-up PR.
//%28 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0
//store i32 3, ptr %28, align 4
//%29 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1
//store i32 3, ptr %29, align 4
//%30 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2
//store ptr %26, ptr %30, align 8
//%31 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3
//store ptr %27, ptr %31, align 8
//%32 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4
//store ptr @.offload_sizes, ptr %32, align 8
//%33 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5
//store ptr @.offload_maptypes, ptr %33, align 8
//%34 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6
//store ptr null, ptr %34, align 8
//%35 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7
//store ptr null, ptr %35, align 8
//%36 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8
//store i64 0, ptr %36, align 8
//%37 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9
//store i64 0, ptr %37, align 8
//%38 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10
//store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %38, align 4
//%39 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11
//store [3 x i32] [i32 256, i32 0, i32 0], ptr %39, align 4
//%40 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12
//store i32 0, ptr %40, align 4
// FIXME(offload): launch kernels
let mut values = vec![];
values.push((4, cx.get_const_i32(3)));
values.push((4, cx.get_const_i32(num_args)));
values.push((8, geps.0));
values.push((8, geps.1));
values.push((8, geps.2));
values.push((8, o_types[0]));
values.push((8, cx.const_null(cx.type_ptr())));
values.push((8, cx.const_null(cx.type_ptr())));
values.push((8, cx.get_const_i64(0)));
values.push((8, cx.get_const_i64(0)));
let ti32 = cx.type_i32();
let ci32_0 = cx.get_const_i32(0);
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
values.push((4, cx.get_const_i32(0)));

for (i, value) in values.iter().enumerate() {
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
}

let args = vec![
s_ident_t,
// MAX == -1
cx.get_const_i64(u64::MAX),
cx.get_const_i32(2097152),
cx.get_const_i32(256),
region_ids[0],
a5,
];
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
unsafe {
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
dbg!(&next);
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
let called_kernel = llvm::LLVMGetCalledValue(next).unwrap();
llvm::LLVMInstructionEraseFromParent(next);
dbg!(&called_kernel);
}

// Step 4)
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };

let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);

builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

drop(builder);
unsafe { llvm::LLVMDeleteFunction(called) };
dbg!("survived");

// With this we generated the following begin and end mappers. We could easily generate the
// update mapper in an update.
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,7 @@ unsafe extern "C" {

// Operations on functions
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
pub(crate) fn LLVMDeleteFunction(Fn: &Value);

// Operations about llvm intrinsics
pub(crate) fn LLVMLookupIntrinsicID(Name: *const c_char, NameLen: size_t) -> c_uint;
Expand Down Expand Up @@ -1238,6 +1239,8 @@ unsafe extern "C" {
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;
pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>;
pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value);

// Operations on call sites
pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint);
Expand Down
Loading
Loading