Skip to content

Commit

Permalink
[flang][cuda] Allocate descriptor in managed memory on rebox block argument (llvm#123971)
Browse files Browse the repository at this point in the history

This is another case where the descriptor must be allocated with the CUF
runtime rather than with a simple alloca instruction.
  • Loading branch information
clementval authored Jan 22, 2025
1 parent afcbcae commit 9f83c4e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
38 changes: 20 additions & 18 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);

if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
operands, rewriter);
return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
operands, rewriter);
return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
inputStrides, operands, rewriter);
return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
inputStrides, operands, rewriter);
}

private:
/// Write resulting shape and base address in descriptor, and replace rebox
/// op.
llvm::LogicalResult
finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange lbounds,
mlir::ValueRange extents, mlir::ValueRange strides,
finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
mlir::ValueRange lbounds, mlir::ValueRange extents,
mlir::ValueRange strides,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
mlir::Value zero =
Expand All @@ -2075,15 +2076,15 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
dest = insertBaseAddress(rewriter, loc, dest, base);
mlir::Value result = placeInMemoryIfNotGlobalInit(
rewriter, rebox.getLoc(), destBoxTy, dest,
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
rewriter.replaceOp(rebox, result);
return mlir::success();
}

// Apply slice given the base address, extents and strides of the input box.
llvm::LogicalResult
sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange inputExtents,
sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = rebox.getLoc();
Expand All @@ -2109,7 +2110,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
if (rebox.getSlice().empty())
// The array section is of the form array[%component][substring], keep
// the input array extents and strides.
return finalizeRebox(rebox, destBoxTy, dest, base,
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
/*lbounds*/ std::nullopt, inputExtents, inputStrides,
rewriter);

Expand Down Expand Up @@ -2158,15 +2159,16 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
slicedStrides.emplace_back(stride);
}
}
return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
slicedExtents, slicedStrides, rewriter);
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
/*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
rewriter);
}

/// Apply a new shape to the data described by a box given the base address,
/// extents and strides of the box.
llvm::LogicalResult
reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
mlir::Value base, mlir::ValueRange inputExtents,
reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
mlir::ValueRange inputStrides, mlir::ValueRange operands,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::ValueRange reboxShifts{
Expand All @@ -2175,7 +2177,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
rebox.getShift().size()};
if (rebox.getShape().empty()) {
// Only setting new lower bounds.
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
inputExtents, inputStrides, rewriter);
}

Expand All @@ -2199,8 +2201,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
// nextStride = extent * stride;
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
}
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
newStrides, rewriter);
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
newExtents, newStrides, rewriter);
}

/// Return scalar element type of the input box.
Expand Down
11 changes: 11 additions & 0 deletions flang/test/Fir/CUDA/cuda-code-gen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec

// CHECK-LABEL: llvm.func @_QPouter
// CHECK: _FortranACUFAllocDescriptor

// -----

// Test input: %arg0 and %arg1 are box block arguments that both carry
// cuf.data_attr = #cuf.cuda<device>; %arg2 is a plain reference argument and
// is not used in the body. Each box argument is reboxed once with no extra
// operands on the rebox. The FileCheck directives that follow this function
// expect two _FortranACUFAllocDescriptor occurrences in the lowered output,
// one per rebox.
func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "da"}, %arg1: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "db"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}) {
%0 = fircg.ext_rebox %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
%1 = fircg.ext_rebox %arg1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
return
}

// CHECK-LABEL: llvm.func @_QMm1Psub1
// CHECK-COUNT-2: _FortranACUFAllocDescriptor

0 comments on commit 9f83c4e

Please sign in to comment.