Support VNNI pre-encoded input in AMX lowering. (#210)
Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Jan 13, 2025
1 parent daa7eb0 commit dc8dfb6
Showing 8 changed files with 435 additions and 75 deletions.

Large diffs are not rendered by default.

223 changes: 223 additions & 0 deletions test/TritonCPU/dot-to-amx.mlir

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions third_party/cpu/language/cpu/__init__.py
@@ -0,0 +1,3 @@
from .utils import vnni_decode

__all__ = ["vnni_decode"]
22 changes: 22 additions & 0 deletions third_party/cpu/language/cpu/utils.py
@@ -0,0 +1,22 @@
from triton import jit
import triton.language as tl
from triton.language.core import builtin


@jit
def _vnni_decode(arg0):
tl.static_assert(len(arg0.shape) == 2)
tmp = arg0.reshape((arg0.shape[0], arg0.shape[1] // 2, 2))
tmp1, tmp2 = tl.split(tmp)
return tl.join(tmp1.T, tmp2.T).reshape((arg0.shape[1] // 2, arg0.shape[0] * 2)).T


@builtin
def vnni_decode(arg0, _builder=None, _generator=None):
bitwidth = arg0.dtype.primitive_bitwidth
if bitwidth > 16:
raise ValueError("Expected 8-bit or 16-bit values for vnni_decode")
decoded = _generator.call_JitFunction(_vnni_decode, (arg0, ), kwargs={})
if bitwidth == 8:
decoded = _generator.call_JitFunction(_vnni_decode, (decoded, ), kwargs={})
return decoded
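
For context, a minimal usage sketch (not part of the commit): a hypothetical kernel loads a B operand that is already stored in VNNI layout as a (K/2, 2*N) block, decodes it back to its logical (K, N) shape, and feeds it to tl.dot. The kernel body and the `triton.language.extra.cpu` import path are assumptions for illustration only.

```python
import triton
import triton.language as tl
# Assumed import path for the CPU extras; the backend may expose it differently.
from triton.language.extra.cpu import vnni_decode


@triton.jit
def dot_vnni_kernel(a_ptr, b_vnni_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A is a plain row-major (M, K) block.
    a = tl.load(a_ptr + tl.arange(0, M)[:, None] * K + tl.arange(0, K)[None, :])
    # B is stored pre-packed in VNNI layout: K/2 rows of 2*N interleaved values.
    b_packed = tl.load(b_vnni_ptr
                       + tl.arange(0, K // 2)[:, None] * (2 * N)
                       + tl.arange(0, 2 * N)[None, :])
    b = vnni_decode(b_packed)  # logically (K, N) again (16-bit case: one decode)
    c = tl.dot(a, b)
    tl.store(c_ptr + tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :], c)
```

The point of the lowering changes below is that this decode sequence can be recognized and folded away, so the AMX path consumes the packed data directly instead of decoding and then re-interleaving it.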
@@ -56,15 +56,80 @@ Value getInitAccValue(Value val) {
return forOp.getInitArgs()[initValIdx];
}

MemBuffer findInputBuffer(Value val, bool allowTransposed) {
namespace {

// Check if val is the result of a transpose operation. If it is, return
// the source of that transpose operation. Otherwise, return nullptr.
Value getTransposedSrc(Value val) {
auto transposeOp = val.getDefiningOp<vector::TransposeOp>();
if (transposeOp)
return transposeOp.getVector();
return nullptr;
}

// We are looking for the following sequence:
// %tmp1, %tmp2 = vector.deinterleave %src
// %tmp3 = vector.transpose %tmp1, [1, 0]
// %tmp4 = vector.transpose %tmp2, [1, 0]
// %tmp5 = vector.interleave %tmp3, %tmp4
// %val = vector.transpose %tmp5, [1, 0]
// and return %src if pattern matching succeeds.
Value getVnniSrcImpl(Value val) {
auto transposedVal = getTransposedSrc(val);
if (!transposedVal)
return nullptr;

auto interleave = transposedVal.getDefiningOp<vector::InterleaveOp>();
if (!interleave)
return nullptr;

auto tmp1 = getTransposedSrc(interleave.getLhs());
auto tmp2 = getTransposedSrc(interleave.getRhs());
if (!tmp1 || !tmp2)
return nullptr;

auto deinterleave1 = tmp1.getDefiningOp<vector::DeinterleaveOp>();
auto deinterleave2 = tmp2.getDefiningOp<vector::DeinterleaveOp>();
if (!deinterleave1 || deinterleave1 != deinterleave2 ||
deinterleave1.getResult(0) != tmp1 || deinterleave2.getResult(1) != tmp2)
return nullptr;

return deinterleave1.getSource();
}

} // namespace

Value getVnniSrc(Value val) {
Type elemTy = getElementTypeOrSelf(val.getType());

// VNNI encoding is used for 8-bit and 16-bit values only.
if (elemTy.getIntOrFloatBitWidth() > 16)
return nullptr;

// For 16-bit values, VNNI encoding is a single interleave of
// subsequent rows. For 8-bit values, it is applied twice.
Value encoded = getVnniSrcImpl(val);
if (encoded && elemTy.getIntOrFloatBitWidth() == 8)
encoded = getVnniSrcImpl(encoded);

return encoded;
}

MemBuffer findInputBuffer(Value val, bool allowTransposed, bool allowVnni) {
MemBuffer buf;

if (allowTransposed) {
auto transposeOp = val.getDefiningOp<vector::TransposeOp>();
if (transposeOp) {
val = transposeOp.getVector();
auto transposed = getTransposedSrc(val);
if (transposed) {
val = transposed;
buf.transposed = true;
}
} else if (allowVnni) {
auto vnniVal = getVnniSrc(val);
if (vnniVal) {
val = vnniVal;
buf.vnni = true;
}
}

auto valLoad = val.getDefiningOp<vector::TransferReadOp>();
@@ -24,6 +24,9 @@ struct MemBuffer {
SmallVector<Value> step;
// True if buffer holds transposed value.
bool transposed = false;
// True if the buffer holds a value in VNNI (interleaved into groups of
// 32 bits) encoding.
bool vnni = false;

bool empty() const { return !memRef; }
};
@@ -48,10 +51,17 @@ template <typename T> bool hasMaskOrBoundsCheck(T op) {
return hasBoundsCheck || mask;
}

// Search for a buffer holding required value. If allowTransposed is true,
// then buffer is allowed to hold both transposed and not transposed value.
// Search for a buffer holding the required value.
//
// If allowTransposed is true, then the buffer is allowed to hold either the
// transposed or the non-transposed value.
//
// If allowVnni is true, then the buffer is allowed to hold the value in
// either its original or VNNI-encoded form. This flag is ignored if
// allowTransposed is true.
//
// Return an empty buffer if no memory holding the value was found.
MemBuffer findInputBuffer(Value val, bool allowTransposed = false);
MemBuffer findInputBuffer(Value val, bool allowTransposed = false,
bool allowVnni = false);

// Cast vector to a specified element type using ext or trunc
// operations. Return the original value if it already matches
@@ -67,6 +77,10 @@ MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy,
Value shiftIndex(Location loc, Value index, int64_t offs,
PatternRewriter &rewriter);

// Check if val is a result of a sequence that performs VNNI decoding.
// If it is, then return the original encoded value. Otherwise, return nullptr.
Value getVnniSrc(Value val);

} // namespace cpu
} // namespace triton
} // namespace mlir
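
To make the recognized layout concrete, here is a small NumPy illustration (not from the commit) of the encoding that getVnniSrc undoes: for 16-bit elements, each pair of adjacent rows is interleaved element-wise, turning a (K, N) operand into a (K/2, 2*N) block whose 32-bit lanes each hold two vertically adjacent values; for 8-bit elements the same transform is applied twice.

```python
import numpy as np


def vnni_encode_16bit(b):
    # Interleave every pair of adjacent rows: (K, N) -> (K/2, 2*N).
    K, N = b.shape
    assert K % 2 == 0
    return b.reshape(K // 2, 2, N).transpose(0, 2, 1).reshape(K // 2, 2 * N)


def vnni_decode_16bit(packed):
    # Inverse transform: (K/2, 2*N) -> (K, N).
    K2, N2 = packed.shape
    return packed.reshape(K2, N2 // 2, 2).transpose(0, 2, 1).reshape(2 * K2, N2 // 2)


b = np.arange(16).reshape(4, 4)
packed = vnni_encode_16bit(b)
# packed[0] == [0, 4, 1, 5, 2, 6, 3, 7]: rows 0 and 1 interleaved element-wise.
assert (vnni_decode_16bit(packed) == b).all()
# For 8-bit data the transform is applied twice, so four adjacent rows end up
# packed into each 32-bit lane.
```

The split/join/transpose sequence emitted by _vnni_decode above is what lowers to the deinterleave/transpose/interleave/transpose chain matched by getVnniSrcImpl.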
@@ -402,9 +402,9 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave,
LDBG("Preparing buffer (interleave=" << interleave
<< ") for a vector: " << val);
auto vecTy = cast<VectorType>(val.getType());
MemBuffer inputBuf = findInputBuffer(val);
MemBuffer inputBuf = findInputBuffer(val, false, interleave);
if (!inputBuf.empty()) {
if (interleave) {
if (interleave && !inputBuf.vnni) {
LDBG(" Copying from the original memref with interleave: "
<< inputBuf.memRef);
auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy),
@@ -426,7 +426,12 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave,
MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter);

if (interleave) {
interleaveAndStore(loc, val, buf.memRef, rewriter);
auto interleavedVal = getVnniSrc(val);
if (interleavedVal) {
LDBG("  Using pre-encoded value: " << interleavedVal);
op_write(interleavedVal, buf.memRef, buf.indices);
} else
interleaveAndStore(loc, val, buf.memRef, rewriter);
} else {
op_write(val, buf.memRef, buf.indices);
}
21 changes: 5 additions & 16 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp
@@ -177,31 +177,20 @@ struct SplitOpConversion : public OpConversionPattern<triton::SplitOp> {
auto src = rewriter.getRemappedValue(op.getSrc());
auto srcTy = cast<VectorType>(src.getType());
auto resTy = getTypeConverter()->convertType(op.getType(0));
assert(srcTy.getShape().back() == 2);

SmallVector<Value> results;
if (srcTy.getRank() == 1) {
results.push_back(rewriter.create<vector::ExtractOp>(loc, src, 0));
results.push_back(rewriter.create<vector::ExtractOp>(loc, src, 1));
rewriter.replaceOp(op, results);
} else {
SmallVector<int64_t> tmpShape({srcTy.getNumElements()});
SmallVector<int64_t> tmpShape(srcTy.getShape().drop_back());
tmpShape.back() *= 2;
auto tmp = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get(tmpShape, srcTy.getElementType()), src);

SmallVector<int64_t> evenIndices;
SmallVector<int64_t> oddIndices;
for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) {
evenIndices.push_back(i);
oddIndices.push_back(i + 1);
}

Value res1 =
rewriter.create<vector::ShuffleOp>(loc, tmp, tmp, evenIndices);
Value res2 =
rewriter.create<vector::ShuffleOp>(loc, tmp, tmp, oddIndices);
results.push_back(rewriter.create<vector::ShapeCastOp>(loc, resTy, res1));
results.push_back(rewriter.create<vector::ShapeCastOp>(loc, resTy, res2));
rewriter.replaceOpWithNewOp<vector::DeinterleaveOp>(op, tmp);
}
rewriter.replaceOp(op, results);
return success();
}
};
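
For reference, a short NumPy sketch (illustrative only) of the semantics the new lowering relies on: after folding the trailing size-2 dimension into the last axis, a single vector.deinterleave produces the even-indexed and odd-indexed lanes, which are exactly the two tensors tl.split must return.

```python
import numpy as np

# tl.split input: any vector with a trailing dimension of size 2.
src = np.arange(8).reshape(2, 2, 2)
# Fold the trailing pair dimension into the last axis, as the new lowering does
# via vector.shape_cast.
tmp = src.reshape(2, 4)
# vector.deinterleave returns the even lanes and the odd lanes of the last axis.
evens, odds = tmp[..., 0::2], tmp[..., 1::2]
assert (evens == src[..., 0]).all()
assert (odds == src[..., 1]).all()
```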