Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of variadic functions as call exprs in Forward mode #1246

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,28 @@ QualType BaseForwardModeVisitor::ComputeDerivativeFunctionType() {
void BaseForwardModeVisitor::SetupDerivativeParameters(
llvm::SmallVectorImpl<ParmVarDecl*>& params) {
const FunctionDecl* FD = m_DiffReq.Function;
for (ParmVarDecl* PVD : FD->parameters()) {
const FunctionDecl* FDPattern = nullptr;
unsigned FDPatternNumParams = 0;
if (FD->isTemplateInstantiation()) {
FDPattern = FD->getTemplateInstantiationPattern();
FDPatternNumParams = FDPattern->getNumParams();
}
for (unsigned i = 0, n = FD->getNumParams(); i < n; ++i) {
const ParmVarDecl* PVD = FD->getParamDecl(i);
IdentifierInfo* PVDII = PVD->getIdentifier();
// Implicitly created special member functions have no parameter names.
if (!PVD->getDeclName())
PVDII = CreateUniqueIdentifier("param");

if (FDPattern) {
const ParmVarDecl* OrigPVD =
i >= FDPatternNumParams
? FDPattern->getParamDecl(FDPatternNumParams - 1)
: FDPattern->getParamDecl(i);
if (OrigPVD->isParameterPack())
PVDII = CreateUniqueIdentifier(PVDII->getName());
}

auto* newPVD = CloneParmVarDecl(PVD, PVDII,
/*pushOnScopeChains=*/true,
/*cloneDefaultArg=*/false);
Expand All @@ -295,7 +311,8 @@ void BaseForwardModeVisitor::SetupDerivativeParameters(
if (PVD == m_IndependentVar)
m_IndependentVar = newPVD;

if (!PVD->getDeclName()) // We can't use lookup-based replacements
if (PVD->getDeclName() !=
newPVD->getDeclName()) // We can't use lookup-based replacements
m_DeclReplacements[PVD] = newPVD;

params.push_back(newPVD);
Expand Down Expand Up @@ -1104,6 +1121,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
skipFirstArg = true;

// For f(g(x)) = f'(x) * g'(x)
std::size_t numParams = FD->getNumParams();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: no header providing "std::size_t" is directly included [misc-include-cleaner]

  std::size_t numParams = FD->getNumParams();
       ^

Expr* Multiplier = nullptr;
for (size_t i = skipFirstArg, e = CE->getNumArgs(); i < e; ++i) {
const Expr* arg = CE->getArg(i);
Expand All @@ -1112,7 +1130,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// If original argument is an RValue and function expects an RValue
// parameter, then convert the cloned argument and the corresponding
// derivative to RValue if they are not RValue.
QualType paramType = FD->getParamDecl(i - skipFirstArg)->getType();
QualType paramType;
if (FD->isVariadic() && (i - skipFirstArg) >= numParams)
paramType = CE->getArg(i - skipFirstArg)->getType();
else
paramType = FD->getParamDecl(i - skipFirstArg)->getType();
if (utils::IsRValue(arg) && paramType->isRValueReferenceType()) {
if (!utils::IsRValue(argDiff.getExpr())) {
Expr* castE = utils::BuildStaticCastToRValue(m_Sema, argDiff.getExpr());
Expand Down
54 changes: 54 additions & 0 deletions test/ForwardMode/VariadicCall.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %cladclang -std=c++17 %s -I%S/../../include -oVariadicCall.out 2>&1 | %filecheck %s
// RUN: ./VariadicCall.out | %filecheck_exec %s

#include "clad/Differentiator/Differentiator.h"

// CHECK: warning: function 'printf' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add -Xclang -verify and use the syntax of // expected-warning ...

// CHECK: note: fallback to numerical differentiation is disabled by the 'CLAD_NO_NUM_DIFF' macro; considering 'printf' as 0

template <typename... T>
double fn1(double x, T... y) {
return (x * ... * y);
}

double fn2(double x, double y) {
printf("x is %f, y is %f\n", x, y);
return fn1(x, y);
}

// CHECK: double fn2_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: printf("x is %f, y is %f\n", x, y);
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = fn1_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

double fn3(double x, double y) {
return fn1(x, y, y);
}

// CHECK: double fn3_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = fn1_pushforward(x, y, y, _d_x, _d_y, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

int main() {
auto d_fn2 = clad::differentiate(fn2, "x");
printf("{%.2f}\n", d_fn2.execute(1.0, 2.0)); // CHECK-EXEC: x is 1.000000, y is 2.000000
// CHECK-EXEC: {2.00}
auto d_fn3 = clad::differentiate(fn3, "x");
printf("{%.2f}\n", d_fn3.execute(1.0, 2.0)); // CHECK-EXEC: {4.00}
return 0;
}

// CHECK: clad::ValueAndPushforward<double, double> fn1_pushforward(double x, double y, double _d_x, double _d_y) {
// CHECK-NEXT: return {x * y, _d_x * y + x * _d_y};
// CHECK-NEXT:}

// CHECK: clad::ValueAndPushforward<double, double> fn1_pushforward(double x, double y, double y2, double _d_x, double _d_y, double _d_y2) {
// CHECK-NEXT: double _t0 = x * y;
// CHECK-NEXT: return {_t0 * y2, (_d_x * y + x * _d_y) * y2 + _t0 * _d_y2};
// CHECK-NEXT:}
Loading