From 78c1c25403789f2c49e0b671310c7138ea00ee7a Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Thu, 5 Sep 2024 14:20:04 -0700 Subject: [PATCH] relax aten.view conversion constraint (#94) Only convert aten.view to tcp.custom_op when the `size` array is non-constant. The rest will be handled through torch-to-tosa. TODO: send out an upstream PR to fix` tosa.reshape` size calculation logic. --- lib/Conversion/TorchToTcp/TcpCustomOp.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index a34bd87..85b0c39 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -299,15 +299,11 @@ class ConvertAtenViewOp : public OpConversionPattern { torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter, getTypeConverter()}; Value self = adaptor.getSelf(); - auto srcType = self.getType().cast(); - auto resultType = - getTypeConverter()->convertType(op.getType()).cast(); - SmallVector size; - // static shape will be handled through TOSA dialect - if (matchPattern(op.getSize(), m_TorchListOfConstantInts(size)) && - srcType.hasStaticShape() && resultType.hasStaticShape()) - return rewriter.notifyMatchFailure(op, "only dynamic shape is supported"); + // static size array will be handled through TOSA dialect + if (matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) + return rewriter.notifyMatchFailure(op, + "only non-constant size is supported"); helper.addOperand("self", self); Operation *primListOp = op.getSize().getDefiningOp();