diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index dbfb51ced..1dc013960 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -281,12 +281,28 @@ QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() { void BaseForwardModeVisitor::SetupDerivativeParameters( llvm::SmallVectorImpl& params) { const FunctionDecl* FD = m_DiffReq.Function; - for (ParmVarDecl* PVD : FD->parameters()) { + const FunctionDecl* FDPattern = nullptr; + unsigned FDPatternNumParams = 0; + if (FD->isTemplateInstantiation()) { + FDPattern = FD->getTemplateInstantiationPattern(); + FDPatternNumParams = FDPattern->getNumParams(); + } + for (unsigned i = 0, n = FD->getNumParams(); i < n; ++i) { + const ParmVarDecl* PVD = FD->getParamDecl(i); IdentifierInfo* PVDII = PVD->getIdentifier(); // Implicitly created special member functions have no parameter names. if (!PVD->getDeclName()) PVDII = CreateUniqueIdentifier("param"); + if (FDPattern) { + const ParmVarDecl* OrigPVD = + i >= FDPatternNumParams + ? FDPattern->getParamDecl(FDPatternNumParams - 1) + : FDPattern->getParamDecl(i); + if (OrigPVD->isParameterPack()) + PVDII = CreateUniqueIdentifier(PVDII->getName()); + } + auto* newPVD = CloneParmVarDecl(PVD, PVDII, /*pushOnScopeChains=*/true, /*cloneDefaultArg=*/false); @@ -295,7 +311,8 @@ void BaseForwardModeVisitor::SetupDerivativeParameters( if (PVD == m_IndependentVar) m_IndependentVar = newPVD; - if (!PVD->getDeclName()) // We can't use lookup-based replacements + if (PVD->getDeclName() != + newPVD->getDeclName()) // We can't use lookup-based replacements m_DeclReplacements[PVD] = newPVD; params.push_back(newPVD); @@ -1104,6 +1121,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { skipFirstArg = true; // For f(g(x)) = f'(x) * g'(x) + std::size_t numParams = FD->getNumParams(); Expr* Multiplier = nullptr; for (size_t i = skipFirstArg, e = CE->getNumArgs(); i < e; ++i) { const Expr* arg = CE->getArg(i); @@ -1112,7 +1130,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // If original argument is an RValue and function expects an RValue // parameter, then convert the cloned argument and the corresponding // derivative to RValue if they are not RValue. - QualType paramType = FD->getParamDecl(i - skipFirstArg)->getType(); + QualType paramType; + if (FD->isVariadic() && (i - skipFirstArg) >= numParams) + paramType = CE->getArg(i - skipFirstArg)->getType(); + else + paramType = FD->getParamDecl(i - skipFirstArg)->getType(); if (utils::IsRValue(arg) && paramType->isRValueReferenceType()) { if (!utils::IsRValue(argDiff.getExpr())) { Expr* castE = utils::BuildStaticCastToRValue(m_Sema, argDiff.getExpr()); diff --git a/test/ForwardMode/VariadicCall.C b/test/ForwardMode/VariadicCall.C new file mode 100644 index 000000000..8625e774d --- /dev/null +++ b/test/ForwardMode/VariadicCall.C @@ -0,0 +1,54 @@ +// RUN: %cladclang -std=c++17 %s -I%S/../../include -oVariadicCall.out 2>&1 | %filecheck %s +// RUN: ./VariadicCall.out | %filecheck_exec %s + +#include "clad/Differentiator/Differentiator.h" + +// CHECK: warning: function 'printf' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives' +// CHECK: note: fallback to numerical differentiation is disabled by the 'CLAD_NO_NUM_DIFF' macro; considering 'printf' as 0 + +template +double fn1(double x, T... y) { + return (x * ... * y); +} + +double fn2(double x, double y) { + printf("x is %f, y is %f\n", x, y); + return fn1(x, y); +} + +// CHECK: double fn2_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: printf("x is %f, y is %f\n", x, y); +// CHECK-NEXT: clad::ValueAndPushforward _t0 = fn1_pushforward(x, y, _d_x, _d_y); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + +double fn3(double x, double y) { + return fn1(x, y, y); +} + +// CHECK: double fn3_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = fn1_pushforward(x, y, y, _d_x, _d_y, _d_y); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + +int main() { + auto d_fn2 = clad::differentiate(fn2, "x"); + printf("{%.2f}\n", d_fn2.execute(1.0, 2.0)); // CHECK-EXEC: x is 1.000000, y is 2.000000 + // CHECK-EXEC: {2.00} + auto d_fn3 = clad::differentiate(fn3, "x"); + printf("{%.2f}\n", d_fn3.execute(1.0, 2.0)); // CHECK-EXEC: {4.00} + return 0; +} + +// CHECK: clad::ValueAndPushforward fn1_pushforward(double x, double y, double _d_x, double _d_y) { +// CHECK-NEXT: return {x * y, _d_x * y + x * _d_y}; +// CHECK-NEXT:} + +// CHECK: clad::ValueAndPushforward fn1_pushforward(double x, double y, double y2, double _d_x, double _d_y, double _d_y2) { +// CHECK-NEXT: double _t0 = x * y; +// CHECK-NEXT: return {_t0 * y2, (_d_x * y + x * _d_y) * y2 + _t0 * _d_y2}; +// CHECK-NEXT:} \ No newline at end of file