Skip to content

Commit

Permalink
[core] Fix various bugs when processing the host-side functions and t…
Browse files Browse the repository at this point in the history
…he (#2268)

* Fix various bugs when processing the host-side functions and the
native calling conventions.

As support for more types is added, it is required that the host
side code have the same calling conventions as the native clang++
compiler. Otherwise, programs will simply fail to execute correctly.

This is part 1. The launch kernel execution path is broken, and
those bugs will be addressed in the next PR.

Signed-off-by: Eric Schweitz <[email protected]>

* Workaround -Werror

Signed-off-by: Eric Schweitz <[email protected]>

* Fix CI Werror issue.

Signed-off-by: Eric Schweitz <[email protected]>

* The requires line didn't work as expected. Script the restriction
as a workaround.

Signed-off-by: Eric Schweitz <[email protected]>

* The assets build doesn't respect the "uname -m" check for some
unknown reason, so just elide the test completely.

Signed-off-by: Eric Schweitz <[email protected]>

* Review comment.

Signed-off-by: Eric Schweitz <[email protected]>

---------

Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi authored Oct 16, 2024
1 parent 7bdab51 commit 03bfdf9
Show file tree
Hide file tree
Showing 11 changed files with 661 additions and 441 deletions.
1 change: 1 addition & 0 deletions docker/build/assets.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ RUN cd /cuda-quantum && source scripts/configure_build.sh && \
# The test is marked correctly as requiring nvcc, but since nvcc
# is available during the build we need to filter it manually.
filtered=" --filter-out MixedLanguage/cuda-1"; \
filtered+="|AST-Quake/calling_convention"; \
fi && \
"$LLVM_INSTALL_PREFIX/bin/llvm-lit" -v build/test \
--param nvqpp_site_config=build/test/lit.site.cfg.py ${filtered} && \
Expand Down
7 changes: 6 additions & 1 deletion include/cudaq/Optimizer/Builder/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ createMonotonicLoop(mlir::OpBuilder &builder, mlir::Location loc,

bool hasHiddenSRet(mlir::FunctionType funcTy);

/// Check a function to see if argument 0 has the `sret` attribute. Typically,
/// one may find this on a host-side entry point function.
bool hasSRet(mlir::func::FuncOp funcOp);

/// Convert the function type \p funcTy to a signature compatible with the code
/// on the host side. This will add hidden arguments, such as the `this`
/// pointer, convert some results to `sret` pointers, etc.
Expand All @@ -251,7 +255,8 @@ bool isX86_64(mlir::ModuleOp);
bool isAArch64(mlir::ModuleOp);

/// A small structure may be passed as two arguments on the host side. (e.g., on
/// the X86-64 ABI.) If \p ty is not a `struct`, this returns `false`.
/// the X86-64 ABI.) If \p ty is not a `struct`, this returns `false`. Note
/// also, some small structs may be packed into a single register.
bool structUsesTwoArguments(mlir::Type ty);

std::optional<std::int64_t> getIntIfConstant(mlir::Value value);
Expand Down
176 changes: 123 additions & 53 deletions lib/Optimizer/Builder/Factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ using namespace mlir;

namespace cudaq::opt {

// The common small struct limit for architectures cudaq is supporting.
static constexpr unsigned CommonSmallStructSize = 128;

bool factory::isX86_64(ModuleOp module) {
std::string triple;
if (auto ta = module->getAttr(targetTripleAttrName))
Expand Down Expand Up @@ -302,33 +305,6 @@ cc::LoopOp factory::createMonotonicLoop(
return loop;
}

// FIXME: some ABIs may return a small struct in registers rather than via an
// sret pointer.
//
// On x86_64,
// pair of: argument return value packed from msb to lsb
// i32 : i64 i64 (second, first)
// i64 : i64, i64 { i64, i64 }
// f32 : <2 x float> <2 x float>
// f64 : double, double { double, double }
//
// On aarch64,
// pair of: argument return value packed from msb to lsb
// i32 : i64 i64 (second, first)
// i64 : [2 x i64] [2 x i64]
// f32 : [2 x float] { float, float }
// f64 : [2 x double] { double, double }
bool factory::hasHiddenSRet(FunctionType funcTy) {
  // Multiple results are always promoted to a structured return argument.
  const auto resultCount = funcTy.getNumResults();
  if (resultCount > 1)
    return true;
  // No result at all: nothing to promote.
  if (resultCount != 1)
    return false;
  // Exactly one result: it is promoted to a structured return argument only
  // when it is one of the aggregate-like types listed here.
  return funcTy.getResult(0)
      .isa<cc::SpanLikeType, cc::StructType, cc::ArrayType,
           cc::CallableType>();
}

cc::StructType factory::stlStringType(MLIRContext *ctx) {
auto i8Ty = IntegerType::get(ctx, 8);
auto ptrI8Ty = cc::PointerType::get(i8Ty);
Expand Down Expand Up @@ -361,8 +337,8 @@ Type factory::getSRetElementType(FunctionType funcTy) {
auto *ctx = funcTy.getContext();
if (funcTy.getNumResults() > 1)
return cc::StructType::get(ctx, funcTy.getResults());
if (isa<cc::SpanLikeType>(funcTy.getResult(0)))
return getDynamicBufferType(ctx);
if (auto spanTy = dyn_cast<cc::SpanLikeType>(funcTy.getResult(0)))
return stlVectorType(spanTy.getElementType());
return funcTy.getResult(0);
}

Expand Down Expand Up @@ -403,33 +379,49 @@ static Type convertToHostSideType(Type ty) {
// function tries to simulate GCC argument passing conventions. classify() also
// has a number of FIXME comments, where it diverges from the referenced ABI.
// Empirical evidence shows that on x86_64, integers and floats are packed in
// integers of size 32 or 64 together, unless the float member fits by itself.
// integers of size 8, 16, 24, 32 or 64 together, unless the float member fits
// by itself.
static bool shouldExpand(SmallVectorImpl<Type> &packedTys,
cc::StructType structTy) {
if (structTy.isEmpty())
return false;
auto *ctx = structTy.getContext();
unsigned bits = 0;
auto scaleBits = [&](unsigned size) {
if (size < 32)
size = (size + 7) & ~7u;
if (size > 32 && size <= 64)
size = 64;
return size;
};

// First split the members into a "lo" set and a "hi" set.
SmallVector<Type> set1;
SmallVector<Type> set2;
for (auto ty : structTy.getMembers()) {
if (auto intTy = dyn_cast<IntegerType>(ty)) {
bits += intTy.getWidth();
if (bits <= 64)
auto addBits = scaleBits(intTy.getWidth());
if (bits + addBits <= 64) {
bits += addBits;
set1.push_back(ty);
else
} else {
bits = std::max(bits, 64u) + addBits;
set2.push_back(ty);
}
} else if (auto fltTy = dyn_cast<FloatType>(ty)) {
bits += fltTy.getWidth();
if (bits <= 64)
auto addBits = fltTy.getWidth();
if (bits + addBits <= 64) {
bits += addBits;
set1.push_back(ty);
else
} else {
bits = std::max(bits, 64u) + addBits;
set2.push_back(ty);
}
} else {
return false;
}
if (bits > CommonSmallStructSize)
return false;
}

// Process the sets. If the set has anything integral, use integer. If the set
Expand All @@ -441,28 +433,83 @@ static bool shouldExpand(SmallVectorImpl<Type> &packedTys,
return true;
return false;
};
auto intSetSize = [&](auto theSet) {
unsigned size = 0;
for (auto ty : theSet)
size += scaleBits(ty.getIntOrFloatBitWidth());
return size;
};
auto processMembers = [&](auto theSet, unsigned packIdx) {
if (useInt(theSet)) {
packedTys[packIdx] = IntegerType::get(ctx, bits > 32 ? 64 : 32);
auto size = intSetSize(theSet);
if (size <= 32)
packedTys[packIdx] = IntegerType::get(ctx, size);
else
packedTys[packIdx] = IntegerType::get(ctx, 64);
} else if (theSet.size() == 1) {
packedTys[packIdx] = theSet[0];
} else {
assert(theSet[0] == FloatType::getF32(ctx) && "must be float");
packedTys[packIdx] =
VectorType::get(ArrayRef<std::int64_t>{2}, theSet[0]);
}
};
assert(!set1.empty() && "struct must have members");
packedTys.resize(set2.empty() ? 1 : 2);
processMembers(set1, 0);
if (!set2.empty())
processMembers(set2, 1);
if (set2.empty())
return false;
processMembers(set2, 1);
return true;
}

/// Report whether argument 0 of \p funcOp carries the LLVM dialect's
/// struct-return attribute. Returns `false` when the function has no
/// arguments or when argument 0 has no attribute dictionary.
bool factory::hasSRet(func::FuncOp funcOp) {
  // Without any arguments there is no argument 0 to inspect.
  if (funcOp.getNumArguments() == 0)
    return false;
  // A null dictionary means no attributes were attached to argument 0.
  auto firstArgAttrs = funcOp.getArgAttrDict(0);
  return firstArgAttrs &&
         firstArgAttrs.contains(LLVM::LLVMDialect::getStructRetAttrName());
}

// On x86_64,
// pair of: argument return value packed from msb to lsb
// i32 : i64 i64 (second, first)
// i64 : i64, i64 { i64, i64 }
// f32 : <2 x float> <2 x float>
// f64 : double, double { double, double }
// ptr : ptr, ptr { ptr, ptr }
//
// On aarch64,
// pair of: argument return value packed from msb to lsb
// i32 : i64 i64 (second, first)
// i64 : [2 x i64] [2 x i64]
// f32 : [2 x float] { float, float }
// f64 : [2 x double] { double, double }
// ptr : [2 x i64] [2 x i64]
bool factory::hasHiddenSRet(FunctionType funcTy) {
  // Decide on the number of results first: no result never needs a structured
  // return argument, while more than one result is always promoted to one.
  switch (funcTy.getNumResults()) {
  case 0:
    return false;
  case 1:
    break;
  default:
    return true;
  }

  // Exactly one result. Span-like, array, and callable results are always
  // promoted to a structured return argument.
  Type soleResult = funcTy.getResult(0);
  if (soleResult.isa<cc::SpanLikeType, cc::ArrayType, cc::CallableType>())
    return true;

  // Any remaining non-struct result is returned directly.
  auto structTy = dyn_cast<cc::StructType>(soleResult);
  if (!structTy)
    return false;

  // A struct result is treated as returned in registers when shouldExpand()
  // either reports `true` or leaves at least one packed type behind. Only when
  // neither holds is the struct promoted to a structured return argument.
  SmallVector<Type> registerTys;
  if (shouldExpand(registerTys, structTy))
    return false;
  return registerTys.empty();
}

bool factory::structUsesTwoArguments(mlir::Type ty) {
// Unchecked! This is only valid if target is X86-64.
auto structTy = dyn_cast<cc::StructType>(ty);
if (!structTy || structTy.getBitSize() == 0 || structTy.getBitSize() > 128)
if (!structTy || structTy.getBitSize() == 0 ||
structTy.getBitSize() > CommonSmallStructSize)
return false;
SmallVector<Type> unused;
return shouldExpand(unused, structTy);
Expand All @@ -486,14 +533,32 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
auto *ctx = funcTy.getContext();
SmallVector<Type> inputTys;
bool hasSRet = false;
if (factory::hasHiddenSRet(funcTy)) {
// When the kernel is returning a std::vector<T> result, the result is
// returned via a sret argument in the first position. When this argument
// is added, the this pointer becomes the second argument. Both are opaque
// pointers at this point.
auto eleTy = convertToHostSideType(getSRetElementType(funcTy));
inputTys.push_back(cc::PointerType::get(eleTy));
hasSRet = true;
Type resultTy;
if (funcTy.getNumResults() == 1)
if (auto strTy = dyn_cast<cc::StructType>(funcTy.getResult(0)))
if (strTy.getBitSize() != 0 &&
strTy.getBitSize() <= CommonSmallStructSize) {
SmallVector<Type, 2> packedTys;
if (shouldExpand(packedTys, strTy) || !packedTys.empty()) {
if (packedTys.size() == 1)
resultTy = packedTys[0];
else
resultTy = cc::StructType::get(ctx, packedTys);
}
}
if (!resultTy && funcTy.getNumResults()) {
if (factory::hasHiddenSRet(funcTy)) {
// When the kernel is returning a std::vector<T> result, the result is
// returned via a sret argument in the first position. When this argument
// is added, the this pointer becomes the second argument. Both are opaque
// pointers at this point.
auto eleTy = convertToHostSideType(getSRetElementType(funcTy));
inputTys.push_back(cc::PointerType::get(eleTy));
hasSRet = true;
} else {
assert(funcTy.getNumResults() == 1);
resultTy = funcTy.getResult(0);
}
}
// If this kernel is a plain old function or a static member function, we
// don't want to add a hidden `this` argument.
Expand All @@ -509,20 +574,25 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
// On x86_64 and aarch64, a struct that is smaller than 128 bits may be
// passed in registers as separate arguments. See classifyArgumentType()
// in CodeGen/TargetInfo.cpp.
if (strTy.getBitSize() != 0 && strTy.getBitSize() <= 128) {
if (strTy.getBitSize() != 0 &&
strTy.getBitSize() <= CommonSmallStructSize) {
if (isX86_64(module)) {
SmallVector<Type, 2> packedTys;
if (shouldExpand(packedTys, strTy)) {
for (auto ty : packedTys)
inputTys.push_back(ty);
continue;
} else if (!packedTys.empty()) {
for (auto ty : packedTys)
inputTys.push_back(ty);
continue;
}
} else {
assert(isAArch64(module) && "aarch64 expected");
if (onlyArithmeticMembers(strTy)) {
// Empirical evidence shows that on aarch64, arguments are packed
// into a single i64 or a [2 x i64] typed value based on the size of
// the struct. This is regardless of whether the value(s) are
// into a single i64 or a [2 x i64] typed value based on the size
// of the struct. This is regardless of whether the value(s) are
// floating-point or not.
if (strTy.getBitSize() > 64)
inputTys.push_back(cc::ArrayType::get(ctx, i64Ty, 2));
Expand All @@ -542,8 +612,8 @@ FunctionType factory::toHostSideFuncType(FunctionType funcTy, bool addThisPtr,
// and it hasn't been converted to a hidden sret argument.
if (funcTy.getNumResults() == 0 || hasSRet)
return FunctionType::get(ctx, inputTys, {});
assert(funcTy.getNumResults() == 1);
return FunctionType::get(ctx, inputTys, funcTy.getResults());
assert(funcTy.getNumResults() == 1 && resultTy);
return FunctionType::get(ctx, inputTys, resultTy);
}

bool factory::isStdVecArg(Type type) {
Expand Down
Loading

0 comments on commit 03bfdf9

Please sign in to comment.