diff --git a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp index 3caf075..3ed702c 100644 --- a/lib/Conversion/TorchToTcp/TcpCustomOp.cpp +++ b/lib/Conversion/TorchToTcp/TcpCustomOp.cpp @@ -154,14 +154,15 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp helper.addIntAttr("quant_max", op.getQuantMax()); // scale - auto scaleTy = adaptor.getScale().dyn_cast(); + auto scaleTy = adaptor.getScale().getType().dyn_cast(); if (!scaleTy || scaleTy.getShape().size() != 1 || scaleTy.getNumElements() != 1) return rewriter.notifyMatchFailure(op, "Unsupported scale type or size"); helper.addOperand("scale", adaptor.getScale()); // zero_point - auto zeroPointTy = adaptor.getZeroPoint().dyn_cast(); + auto zeroPointTy = + adaptor.getZeroPoint().getType().dyn_cast(); if (!zeroPointTy || zeroPointTy.getShape().size() != 1 || zeroPointTy.getNumElements() != scaleTy.getNumElements()) return rewriter.notifyMatchFailure(op, @@ -188,13 +189,14 @@ class ConvertAtenFakeQuantizePerChannelAffineOp helper.addIntAttr("quant_max", op.getQuantMax()); // scale - auto scaleTy = adaptor.getScale().dyn_cast(); + auto scaleTy = adaptor.getScale().getType().dyn_cast(); if (!scaleTy || scaleTy.getShape().size() != 1) return rewriter.notifyMatchFailure(op, "Unsupported scale type or size"); helper.addOperand("scale", adaptor.getScale()); // zero_point - auto zeroPointTy = adaptor.getZeroPoint().dyn_cast(); + auto zeroPointTy = + adaptor.getZeroPoint().getType().dyn_cast(); if (!zeroPointTy || zeroPointTy.getShape().size() != 1 || zeroPointTy.getNumElements() != scaleTy.getNumElements()) return rewriter.notifyMatchFailure(op,