Skip to content

Commit

Permalink
[CombToAIG] Add support for div/mod operations (#8130)
Browse files Browse the repository at this point in the history
This patch adds support for lowering comb div/mod operations to AIG:
* Support unsigned div/mod for power-of-2 constant divisors 
* Support for non-power-of-2 divisors by emulating the operation with a mux tree when the number of unknown bits is small (default threshold is 10 bits). This is not ideal and eventually we should lower to a proper div/mod. Practically this lowering pattern covers many cases since users rarely use div/mod because these operations are very expensive (users write their own div/mod in their frontend language).  
* LEC tests are added.
  • Loading branch information
uenoku authored Feb 3, 2025
1 parent 1b33c44 commit ee62dbc
Show file tree
Hide file tree
Showing 4 changed files with 435 additions and 8 deletions.
2 changes: 2 additions & 0 deletions include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,8 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> {
let options = [
ListOption<"additionalLegalOps", "additional-legal-ops", "std::string",
"Specify additional legal ops for testing">,
Option<"maxEmulationUnknownBits", "max-emulation-unknown-bits", "uint32_t", "10",
"Maximum number of unknown bits to emulate in a table lookup">
];
}

Expand Down
53 changes: 53 additions & 0 deletions integration_test/circt-synth/divmod.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// REQUIRES: libz3
// REQUIRES: circt-lec-jit

// RUN: circt-opt %s --hw-aggregate-to-comb --convert-comb-to-aig --convert-aig-to-comb -o %t.mlir

// RUN: circt-lec %t.mlir %s -c1=divmodu -c2=divmodu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODU
// COMB_DIVMODU: c1 == c2
hw.module @divmodu(in %lhs: i3, in %rhs: i3, out out_div: i3, out out_mod: i3) {
%c0_i3 = hw.constant 0 : i3
%neq = comb.icmp ne %rhs, %c0_i3 : i3
verif.assume %neq : i1

%0 = comb.divu %lhs, %rhs : i3
%1 = comb.modu %lhs, %rhs : i3
hw.output %0, %1 : i3, i3
}

// RUN: circt-lec %t.mlir %s -c1=divmodu_power_of_two -c2=divmodu_power_of_two --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODU_POWER_OF_TWO
// COMB_DIVMODU_POWER_OF_TWO: c1 == c2
hw.module @divmodu_power_of_two(in %lhs: i8, out out_div: i8, out out_mod: i8) {
%c16_i8 = hw.constant 16 : i8

%0 = comb.divu %lhs, %c16_i8 : i8
%1 = comb.modu %lhs, %c16_i8 : i8
hw.output %0, %1 : i8, i8
}

// RUN: circt-lec %t.mlir %s -c1=divmods -c2=divmods --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMODS
// COMB_DIVMODS: c1 == c2
hw.module @divmods(in %lhs: i3, in %rhs: i3, out out_div: i3, out out_mod: i3) {
%c0_i3 = hw.constant 0 : i3
%neq = comb.icmp ne %rhs, %c0_i3 : i3
verif.assume %neq : i1

%0 = comb.divs %lhs, %rhs : i3
%1 = comb.mods %lhs, %rhs : i3
hw.output %0, %1 : i3, i3
}

// RUN: circt-lec %t.mlir %s -c1=divmod_mix_constant -c2=divmod_mix_constant --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVMOD_MIX_CONSTANT
// COMB_DIVMOD_MIX_CONSTANT: c1 == c2
hw.module @divmod_mix_constant(in %in: i1, in %lhs: i1, in %rhs: i1, out out_divu: i4, out out_modu: i4, out out_divs: i4, out out_mods: i4) {
%c2_i2 = hw.constant 2 : i2

%new_lhs = comb.concat %in, %c2_i2, %lhs : i1, i2, i1
%new_rhs = comb.concat %c2_i2, %rhs, %in : i2, i1, i1
%0 = comb.divu %new_lhs, %new_rhs : i4
%1 = comb.modu %new_lhs, %new_rhs : i4
%2 = comb.divs %new_lhs, %new_rhs : i4
%3 = comb.mods %new_lhs, %new_rhs : i4
hw.output %0, %1, %2, %3 : i4, i4, i4, i4
}

277 changes: 269 additions & 8 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/PointerUnion.h"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOAIG
Expand Down Expand Up @@ -90,6 +91,143 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
outOfBoundsValue);
}

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits and populate the concatenated values.
static int64_t getNumUnknownBitsAndPopulateValues(
Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
// Constant or zero width value are all known.
if (value.getType().isInteger(0))
return 0;

// Recursively count unknown bits for concat.
if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
int64_t totalUnknownBits = 0;
for (auto concatInput : llvm::reverse(concat.getInputs())) {
auto unknownBits =
getNumUnknownBitsAndPopulateValues(concatInput, values);
if (unknownBits < 0)
return unknownBits;
totalUnknownBits += unknownBits;
}
return totalUnknownBits;
}

// Constant value is known.
if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
values.push_back(constant.getValueAttr());
return 0;
}

// Consider other operations as unknown bits.
// TODO: We can handle replicate, extract, etc.
values.push_back(value);
return hw::getBitWidth(value.getType());
}

// Return a value that substitutes the unknown bits with the mask.
static APInt
substitueMaskToValues(size_t width,
llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
uint32_t mask) {
uint32_t bitPos = 0, unknownPos = 0;
APInt result(width, 0);
for (auto constantOrValue : constantOrValues) {
int64_t elemWidth;
if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
elemWidth = constant.getValue().getBitWidth();
result.insertBits(constant.getValue(), bitPos);
} else {
elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
assert(elemWidth >= 0 && "unknown bit width");
assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
// Create a mask for the unknown bits.
uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
result.insertBits(APInt(elemWidth, usedBits), bitPos);
unknownPos += elemWidth;
}
bitPos += elemWidth;
}

return result;
}

// Emulate a binary operation with unknown bits using a table lookup.
// This function enumerates all possible combinations of unknown bits and
// emulates the operation for each combination.
static LogicalResult emulateBinaryOpForUnknownBits(
ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
Operation *op,
llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
SmallVector<ConstantOrValue> lhsValues, rhsValues;

assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
"op must be a single result binary operation");

auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
auto loc = op->getLoc();
auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

// If unknown bit width is detected, abort the lowering.
if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
return failure();

int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
if (totalUnknownBits > maxEmulationUnknownBits)
return failure();

SmallVector<Value> emulatedResults;
emulatedResults.reserve(1 << totalUnknownBits);

// Emulate all possible cases.
DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
auto it = constantPool.find(attr);
if (it != constantPool.end())
return it->second;
auto constant = rewriter.create<hw::ConstantOp>(loc, value);
constantPool[attr] = constant;
return constant;
};

for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
lhsMask < lhsMaskEnd; ++lhsMask) {
APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
rhsMask < rhsMaskEnd; ++rhsMask) {
APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
// Emulate.
emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
}
}

// Create selectors for mux tree.
SmallVector<Value> selectors;
selectors.reserve(totalUnknownBits);
for (auto &concatedValues : {rhsValues, lhsValues})
for (auto valueOrConstant : concatedValues) {
auto value = dyn_cast<Value>(valueOrConstant);
if (!value)
continue;
extractBits(rewriter, value, selectors);
}

assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
"number of selectors must match");
auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
getConstant(APInt::getZero(width)));

rewriter.replaceOp(op, muxed);
return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -360,6 +498,121 @@ struct CombMulOpConversion : OpConversionPattern<MulOp> {
}
};

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
: OpConversionPattern<OpTy>(context),
maxEmulationUnknownBits(maxEmulationUnknownBits) {
assert(maxEmulationUnknownBits < 32 &&
"maxEmulationUnknownBits must be less than 32");
}
const int64_t maxEmulationUnknownBits;
};

struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
LogicalResult
matchAndRewrite(DivUOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if the divisor is a power of two.
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
if (rhsConstantOp.getValue().isPowerOf2()) {
// Extract upper bits.
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
size_t width = op.getType().getIntOrFloatBitWidth();
Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), adaptor.getLhs(), extractAmount,
width - extractAmount);
Value constZero = rewriter.create<hw::ConstantOp>(
op.getLoc(), APInt::getZero(extractAmount));
rewriter.replaceOpWithNewOp<comb::ConcatOp>(
op, op.getType(), ArrayRef<Value>{constZero, upperBits});
return success();
}

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
// Division by zero is undefined, just return zero.
if (rhs.isZero())
return APInt::getZero(rhs.getBitWidth());
return lhs.udiv(rhs);
});
}
};

struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
LogicalResult
matchAndRewrite(ModUOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if the divisor is a power of two.
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
if (rhsConstantOp.getValue().isPowerOf2()) {
// Extract lower bits.
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
size_t width = op.getType().getIntOrFloatBitWidth();
Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), adaptor.getLhs(), 0, extractAmount);
Value constZero = rewriter.create<hw::ConstantOp>(
op.getLoc(), APInt::getZero(width - extractAmount));
rewriter.replaceOpWithNewOp<comb::ConcatOp>(
op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
return success();
}

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
// Division by zero is undefined, just return zero.
if (rhs.isZero())
return APInt::getZero(rhs.getBitWidth());
return lhs.urem(rhs);
});
}
};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

LogicalResult
matchAndRewrite(DivSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently only lower with emulation.
// TODO: Implement a signed division lowering at least for power of two.
return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
// Division by zero is undefined, just return zero.
if (rhs.isZero())
return APInt::getZero(rhs.getBitWidth());
return lhs.sdiv(rhs);
});
}
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
LogicalResult
matchAndRewrite(ModSOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Currently only lower with emulation.
// TODO: Implement a signed modulus lowering at least for power of two.
return emulateBinaryOpForUnknownBits(
rewriter, maxEmulationUnknownBits, op,
[](const APInt &lhs, const APInt &rhs) {
// Division by zero is undefined, just return zero.
if (rhs.isZero())
return APInt::getZero(rhs.getBitWidth());
return lhs.srem(rhs);
});
}
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
using OpConversionPattern<ICmpOp>::OpConversionPattern;
static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
Expand Down Expand Up @@ -565,8 +818,8 @@ struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
auto sign =
rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

// NOTE: The max shift amount is width - 1 because the sign bit is already
// shifted out.
// NOTE: The max shift amount is width - 1 because the sign bit is
// already shifted out.
auto result = createShiftLogic</*isLeftShift=*/false>(
rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
/*getPadding=*/
Expand Down Expand Up @@ -597,10 +850,13 @@ struct ConvertCombToAIGPass
void runOnOperation() override;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
};
} // namespace

static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
uint32_t maxEmulationUnknownBits) {
patterns.add<
// Bitwise Logical Ops
CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
Expand All @@ -613,6 +869,11 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
CombLowerVariadicOp<MulOp>>(patterns.getContext());

// Add div/mod patterns with a threshold given by the pass option.
patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
CombModSOpConversion>(patterns.getContext(),
maxEmulationUnknownBits);
}

void ConvertCombToAIGPass::runOnOperation() {
Expand All @@ -624,10 +885,10 @@ void ConvertCombToAIGPass::runOnOperation() {
target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
hw::BitcastOp, hw::ConstantOp>();

// Treat array operations as illegal. Strictly speaking, other than array get
// operation with non-const index are legal in AIG but array types prevent a
// bunch of optimizations so just lower them to integer operations. It's
// required to run HWAggregateToComb pass before this pass.
// Treat array operations as illegal. Strictly speaking, other than array
// get operation with non-const index are legal in AIG but array types
// prevent a bunch of optimizations so just lower them to integer
// operations. It's required to run HWAggregateToComb pass before this pass.
target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
hw::AggregateConstantOp>();

Expand All @@ -640,7 +901,7 @@ void ConvertCombToAIGPass::runOnOperation() {
target.addLegalOp(OperationName(opName, &getContext()));

RewritePatternSet patterns(&getContext());
populateCombToAIGConversionPatterns(patterns);
populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits);

if (failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
Expand Down
Loading

0 comments on commit ee62dbc

Please sign in to comment.