From 846ff1d8fb3cc47783263064fad3a954bafcdd80 Mon Sep 17 00:00:00 2001 From: Bruce Lai Date: Tue, 19 Nov 2024 21:59:45 -0800 Subject: [PATCH] Fix error by checking for m.size() Signed-off-by: Bruce Lai --- .../src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp index 91a6ea92eacb..c1a7fa5a432f 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp @@ -1334,6 +1334,10 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, // TODO: support int data type return; } + FailureOr cDims = + linalg::inferContractionDims(op); + if (failed(cDims) || cDims->m.size() != 1) + return; // Use 7 x lmul4 to fully utilize vector registers. sizes[0] = 7; // Calculate tile size for the main vector dimension (N). @@ -1342,10 +1346,6 @@ getMatmulRISCVVectorSizes(mlir::FunctionOpInterface entryPointFn, (nativeVectorSize * 2 * kByteSizeInBits) / elementSize; sizes[1] = maxNumberElementsForLMUL4; sizes[2] = 1; - FailureOr cDims = - linalg::inferContractionDims(op); - if (failed(cDims)) - return; ArrayRef lhsShape = op.getShape(op.getDpsInputOperand(0)); // If m = 1, set tile size to 1 x lmul8 if (lhsShape[cDims->m[0]] == 1) {