diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 677ccc4f241b..cc21f2155e46 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1059,44 +1059,44 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
   LogicalResult matchAndRewrite(AtenEyeMOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    int64_t n;
-
-    if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: n must be constant");
-    int64_t m;
-    if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
-      return rewriter.notifyMatchFailure(op,
-                                         "unimplemented: m must be constant");
-    Value none = rewriter.create<ConstantNoneOp>(loc);
-    auto outType = dyn_cast<BaseTensorType>(op.getType());
+    auto outType = op.getType().dyn_cast<BaseTensorType>();
     if (!outType)
       return rewriter.notifyMatchFailure(
           op, "Only tensor types input are currently supported");
     if (!outType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    if (n < 0) {
-      return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
-    }
-    if (m < 0) {
-      return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
-    }
-
+    Value none = rewriter.create<ConstantNoneOp>(loc);
     auto context = op.getContext();
     auto int64Dtype = getDtypeIntValueForType(
         rewriter, loc,
         rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
-    auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
+
+    int64_t n = kUnknownSize;
+    int64_t m = kUnknownSize;
+    // prioritize getting shape from output shape
+    if (outType.hasSizes() && outType.getSizes().size() == 2) {
+      n = outType.getSizes().front();
+      m = outType.getSizes().back();
+    }
+    // if output shape is not available, try to get shape from input
+    if (n == kUnknownSize)
+      matchPattern(op.getN(), m_TorchConstantInt(&n));
+    if (m == kUnknownSize)
+      matchPattern(op.getM(), m_TorchConstantInt(&m));
+
+    // prepare two unsqueezed ranges that are equal on and only on the diagonal
+    auto rangeNSize = llvm::SmallVector<int64_t>({n});
+    Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
     Value rangeN = rewriter.create<AtenArangeOp>(
-        loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/op.getDevice(), /*pin_memory=*/none);
-    auto arangeType1 =
-        outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
+    auto rangeMSize = llvm::SmallVector<int64_t>({m});
+    Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
     Value rangeM = rewriter.create<AtenArangeOp>(
-        loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
+        loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);
 
     Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
@@ -1109,7 +1109,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
     }
     Value unsqzRangeN = *unsqzTensorInfo;
 
-    // compare unsqueezed input with boundaries
    auto eqType = ValueTensorType::get(
        context, cast<BaseTensorType>(op.getType()).getSizes(),
        IntegerType::get(context, 1));