Skip to content

Commit

Permalink
[CombToAIG] Lower comb.icmp
Browse files Browse the repository at this point in the history
This commit adds a pattern for icmp. Lowering is straight-forward
but fairly complicated. LEC is verified
  • Loading branch information
uenoku committed Dec 24, 2024
1 parent 40f5f03 commit fb1c2a4
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 4 deletions.
28 changes: 28 additions & 0 deletions integration_test/circt-synth/comb-lowering-lec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,31 @@ hw.module @mul(in %arg0: i3, in %arg1: i3, in %arg2: i3, out add: i3) {
%0 = comb.mul %arg0, %arg1, %arg2 : i3
hw.output %0 : i3
}

// RUN: circt-lec %t.mlir %s -c1=icmp_eq_ne -c2=icmp_eq_ne --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_EQ_NE
// COMB_ICMP_EQ_NE: c1 == c2
hw.module @icmp_eq_ne(in %lhs: i3, in %rhs: i3, out out_eq: i1, out out_ne: i1) {
%eq = comb.icmp eq %lhs, %rhs : i3
%ne = comb.icmp ne %lhs, %rhs : i3
hw.output %eq, %ne : i1, i1
}

// RUN: circt-lec %t.mlir %s -c1=icmp_unsigned_compare -c2=icmp_unsigned_compare --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_UNSIGNED_COMPARE
// COMB_ICMP_UNSIGNED_COMPARE: c1 == c2
hw.module @icmp_unsigned_compare(in %lhs: i3, in %rhs: i3, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) {
%ugt = comb.icmp ugt %lhs, %rhs : i3
%uge = comb.icmp uge %lhs, %rhs : i3
%ult = comb.icmp ult %lhs, %rhs : i3
%ule = comb.icmp ule %lhs, %rhs : i3
hw.output %ugt, %uge, %ult, %ule : i1, i1, i1, i1
}

// RUN: circt-lec %t.mlir %s -c1=icmp_signed_compare -c2=icmp_signed_compare --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ICMP_SIGNED_COMPARE
// COMB_ICMP_SIGNED_COMPARE: c1 == c2
hw.module @icmp_signed_compare(in %lhs: i3, in %rhs: i3, out out_sgt: i1, out out_sge: i1, out out_slt: i1, out out_sle: i1) {
%sgt = comb.icmp sgt %lhs, %rhs : i3
%sge = comb.icmp sge %lhs, %rhs : i3
%slt = comb.icmp slt %lhs, %rhs : i3
%sle = comb.icmp sle %lhs, %rhs : i3
hw.output %sgt, %sge, %slt, %sle : i1, i1, i1, i1
}
118 changes: 116 additions & 2 deletions lib/Conversion/CombToAIG/CombToAIG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,119 @@ struct CombMulOpConversion : OpConversionPattern<MulOp> {
}
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
using OpConversionPattern<ICmpOp>::OpConversionPattern;
static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
ArrayRef<Value> bBits, bool isLess,
bool includeEq,
ConversionPatternRewriter &rewriter) {
// Construct following unsigned comparison expressions.
// a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
// a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0])
// a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
// a > b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] > b[n-1:0])
Value acc =
rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);

for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
auto aBitXorBBit =
rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
op.getLoc(), aBitXorBBit, true);
auto pred = rewriter.createOrFold<aig::AndInverterOp>(
op.getLoc(), aBit, bBit, isLess, !isLess);

auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
op.getLoc(), ValueRange{aEqualB, acc}, true);
acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
true);
}
return acc;
}

LogicalResult
matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto lhs = adaptor.getLhs();
auto rhs = adaptor.getRhs();

switch (op.getPredicate()) {
default:
return failure();

case ICmpPredicate::eq:
case ICmpPredicate::ceq: {
// a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
auto xorBits = extractBits(rewriter, xorOp);
SmallVector<bool> allInverts(xorBits.size(), true);
rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
return success();
}

case ICmpPredicate::ne:
case ICmpPredicate::cne: {
// a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
rewriter.replaceOpWithNewOp<comb::OrOp>(op, extractBits(rewriter, xorOp),
true);
return success();
}

case ICmpPredicate::uge:
case ICmpPredicate::ugt:
case ICmpPredicate::ule:
case ICmpPredicate::ult: {
bool isLess = op.getPredicate() == ICmpPredicate::ult ||
op.getPredicate() == ICmpPredicate::ule;
bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
op.getPredicate() == ICmpPredicate::ule;
auto aBits = extractBits(rewriter, lhs);
auto bBits = extractBits(rewriter, rhs);
rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
includeEq, rewriter));
return success();
}
case ICmpPredicate::slt:
case ICmpPredicate::sle:
case ICmpPredicate::sgt:
case ICmpPredicate::sge: {
if (lhs.getType().getIntOrFloatBitWidth() == 0)
return rewriter.notifyMatchFailure(
op.getLoc(), "i0 signed comparison is unsupported");
bool isLess = op.getPredicate() == ICmpPredicate::slt ||
op.getPredicate() == ICmpPredicate::sle;
bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
op.getPredicate() == ICmpPredicate::sle;

auto aBits = extractBits(rewriter, lhs);
auto bBits = extractBits(rewriter, rhs);

// Get a sign bit
auto signA = aBits.back();
auto signB = bBits.back();

// Compare magnitudes (all bits except sign)
auto sameSignResult = constructUnsignedCompare(
op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
includeEq, rewriter);

// XOR of signs: true if signs are different
auto signsDiffer =
rewriter.create<comb::XorOp>(op.getLoc(), signA, signB);

// Result when signs are different
Value diffSignResult = isLess ? signA : signB;

// Final result: choose based on whether signs differ
rewriter.replaceOpWithNewOp<comb::MuxOp>(op, signsDiffer, diffSignResult,
sameSignResult);
return success();
}
}
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -350,9 +463,10 @@ static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
CombMuxOpConversion,
// Arithmetic Ops
CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
CombICmpOpConversion,
// Variadic ops that must be lowered to binary operations
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>, CombLowerVariadicOp<MulOp>>(
patterns.getContext());
CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
CombLowerVariadicOp<MulOp>>(patterns.getContext());
}

void ConvertCombToAIGPass::runOnOperation() {
Expand Down
77 changes: 75 additions & 2 deletions test/Conversion/CombToAIG/comb-to-aig-arith.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux}))" | FileCheck %s
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add}))" | FileCheck %s --check-prefix=ALLOW_ADD
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux},cse))" | FileCheck %s
// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-comb-to-aig{additional-legal-ops=comb.xor,comb.or,comb.and,comb.mux,comb.add},cse))" | FileCheck %s --check-prefix=ALLOW_ADD


// CHECK-LABEL: @add
Expand Down Expand Up @@ -46,3 +46,76 @@ hw.module @mul(in %lhs: i2, in %rhs: i2, out out: i2) {
%0 = comb.mul %lhs, %rhs : i2
hw.output %0 : i2
}

// CHECK-LABEL: @icmp_eq_ne
hw.module @icmp_eq_ne(in %lhs: i2, in %rhs: i2, out out_eq: i1, out out_ne: i1) {
%eq = comb.icmp eq %lhs, %rhs : i2
%ne = comb.icmp ne %lhs, %rhs : i2
// CHECK-NEXT: %[[XOR:.+]] = comb.xor %lhs, %rhs
// CHECK-NEXT: %[[XOR_0:.+]] = comb.extract %[[XOR]] from 0 : (i2) -> i1
// CHECK-NEXT: %[[XOR_1:.+]] = comb.extract %[[XOR]] from 1 : (i2) -> i1
// CHECK-NEXT: %[[EQ:.+]] = aig.and_inv not %[[XOR_0]], not %[[XOR_1]]
// CHECK-NEXT: %[[NEQ:.+]] = comb.or bin %[[XOR_0]], %[[XOR_1]]
// CHECK-NEXT: hw.output %[[EQ]], %[[NEQ]]
// CHECK-NEXT: }
hw.output %eq, %ne : i1, i1
}

// CHECK-LABEL: @icmp_unsigned_compare
hw.module @icmp_unsigned_compare(in %lhs: i2, in %rhs: i2, out out_ugt: i1, out out_uge: i1, out out_ult: i1, out out_ule: i1) {
%ugt = comb.icmp ugt %lhs, %rhs : i2
%uge = comb.icmp uge %lhs, %rhs : i2
%ult = comb.icmp ult %lhs, %rhs : i2
%ule = comb.icmp ule %lhs, %rhs : i2
// CHECK-NEXT: %[[LHS_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[LHS_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[RHS_0:.+]] = comb.extract %rhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[RHS_1:.+]] = comb.extract %rhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[LSB_NEQ:.+]] = comb.xor bin %[[LHS_0]], %[[RHS_0]]
// CHECK-NEXT: %[[LSB_GT:.+]] = aig.and_inv %[[LHS_0]], not %[[RHS_0]]
// CHECK-NEXT: %[[MSB_NEQ:.+]] = comb.xor bin %[[LHS_1]], %[[RHS_1]]
// CHECK-NEXT: %[[MSB_EQ:.+]] = aig.and_inv not %[[MSB_NEQ]]
// CHECK-NEXT: %[[MSB_GT:.+]] = aig.and_inv %[[LHS_1]], not %[[RHS_1]]
// CHECK-NEXT: %[[MSB_EQ_AND_LSB_GT:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_GT]]
// CHECK-NEXT: %[[UGT:.+]] = comb.or bin %[[MSB_GT]], %[[MSB_EQ_AND_LSB_GT]]
// CHECK-NEXT: %[[LSB_EQ:.+]] = aig.and_inv not %[[LSB_NEQ]]
// CHECK-NEXT: %[[LSB_UGE:.+]] = comb.or bin %[[LSB_GT]], %[[LSB_EQ]]
// CHECK-NEXT: %[[MSB_EQ_AND_LSB_UGE:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_UGE]]
// CHECK-NEXT: %[[UGE:.+]] = comb.or bin %[[MSB_GT]], %[[MSB_EQ_AND_LSB_UGE]]
// CHECK-NEXT: %[[LSB_LT:.+]] = aig.and_inv not %[[LHS_0]], %[[RHS_0]]
// CHECK-NEXT: %[[MSB_LT:.+]] = aig.and_inv not %[[LHS_1]], %[[RHS_1]]
// CHECK-NEXT: %[[MSB_EQ_AND_LSB_LT:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_LT]]
// CHECK-NEXT: %[[ULT:.+]] = comb.or bin %[[MSB_LT]], %[[MSB_EQ_AND_LSB_LT]]
// CHECK-NEXT: %[[LSB_LE:.+]] = comb.or bin %[[LSB_LT]], %[[LSB_EQ]]
// CHECK-NEXT: %[[MSB_EQ_AND_LSB_LE:.+]] = comb.and bin %[[MSB_EQ]], %[[LSB_LE]]
// CHECK-NEXT: %[[ULE:.+]] = comb.or bin %[[MSB_LT]], %[[MSB_EQ_AND_LSB_LE]]
// CHECK-NEXT: hw.output %[[UGT]], %[[UGE]], %[[ULT]], %[[ULE]]
// CHECK-NEXT: }
hw.output %ugt, %uge, %ult, %ule : i1, i1, i1, i1
}

// CHECK-LABEL: @icmp_signed_compare
hw.module @icmp_signed_compare(in %lhs: i2, in %rhs: i2, out out_sgt: i1, out out_sge: i1, out out_slt: i1, out out_sle: i1) {
%sgt = comb.icmp sgt %lhs, %rhs : i2
%sge = comb.icmp sge %lhs, %rhs : i2
%slt = comb.icmp slt %lhs, %rhs : i2
%sle = comb.icmp sle %lhs, %rhs : i2
// CHECK-NEXT: %[[LHS_0:.+]] = comb.extract %lhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[LHS_1:.+]] = comb.extract %lhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[RHS_0:.+]] = comb.extract %rhs from 0 : (i2) -> i1
// CHECK-NEXT: %[[RHS_1:.+]] = comb.extract %rhs from 1 : (i2) -> i1
// CHECK-NEXT: %[[LSB_NEQ:.+]] = comb.xor bin %[[LHS_0]], %[[RHS_0]]
// CHECK-NEXT: %[[LSB_GT:.+]] = aig.and_inv %[[LHS_0]], not %[[RHS_0]]
// CHECK-NEXT: %[[SIGN_NEQ:.+]] = comb.xor %[[LHS_1]], %[[RHS_1]]
// CHECK-NEXT: %[[SGT:.+]] = comb.mux %[[SIGN_NEQ]], %[[RHS_1]], %[[LSB_GT]]
// CHECK-NEXT: %[[LSB_EQ:.+]] = aig.and_inv not %[[LSB_NEQ]]
// CHECK-NEXT: %[[LSB_GE:.+]] = comb.or bin %[[LSB_GT]], %[[LSB_EQ]]
// CHECK-NEXT: %[[SGE:.+]] = comb.mux %[[SIGN_NEQ]], %[[RHS_1]], %[[LSB_GE]]
// CHECK-NEXT: %[[LSB_LT:.+]] = aig.and_inv not %[[LHS_0]], %[[RHS_0]]
// CHECK-NEXT: %[[SLT:.+]] = comb.mux %[[SIGN_NEQ]], %[[LHS_1]], %[[LSB_LT]]
// CHECK-NEXT: %[[LSB_LE:.+]] = comb.or bin %[[LSB_LT]], %[[LSB_EQ]]
// CHECK-NEXT: %[[SLE:.+]] = comb.mux %[[SIGN_NEQ]], %[[LHS_1]], %[[LSB_LE]]
// CHECK-NEXT: hw.output %[[SGT]], %[[SGE]], %[[SLT]], %[[SLE]]
// CHECK-NEXT: }
hw.output %sgt, %sge, %slt, %sle : i1, i1, i1, i1
}

0 comments on commit fb1c2a4

Please sign in to comment.