Skip to content

Commit

Permalink
[AD] Add support for resolving custom derivatives where generic param…
Browse files Browse the repository at this point in the history
…eters can't be automatically inferred (#5630)

* [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred

* Fix failing tests

* Update custom-derivative-generic.slang
  • Loading branch information
saipraveenb25 authored Nov 22, 2024
1 parent 95125f2 commit 9913cfb
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 2 deletions.
76 changes: 75 additions & 1 deletion source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl(
SemanticsContext::ExprLocalScope scope;
auto ctx = visitor->withExprLocalScope(&scope);
auto subVisitor = SemanticsVisitor(ctx);
auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);

auto exprToCheck = attr->funcExpr;

// If this is a generic, we want to wrap the call to the derivative method
// with the generic parameters of the source.
//
if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr))
{
auto genericDecl = as<GenericDecl>(funcDecl->parentDecl);
auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl);
auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>();

Index count = 0;
for (auto member : genericDecl->members)
{
if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) ||
as<GenericTypePackParamDecl>(member))
count++;
}

appExpr->functionExpr = attr->funcExpr;

for (auto arg : substArgs)
{
if (count == 0)
break;

if (auto declRefType = as<DeclRefType>(arg))
{
auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>();
baseTypeExpr->base.type = declRefType;
auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType);
baseTypeExpr->type.type = baseTypeType;

appExpr->arguments.add(baseTypeExpr);
}
else if (auto genericValParam = as<GenericParamIntVal>(arg))
{
auto declRef = genericValParam->getDeclRef();
appExpr->arguments.add(
subVisitor
.ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr));
}
else
{
SLANG_UNEXPECTED("Unhandled substitution arg type");
}

count--;
}

exprToCheck = appExpr;
}

auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx);
attr->funcExpr = checkedFuncExpr;
if (attr->args.getCount())
attr->args[0] = attr->funcExpr;
Expand Down Expand Up @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl(
calleeDeclRef = calleeDeclRefExpr->declRef;

auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());

if (!calleeFunc)
{
// If we couldn't find a direct function, it might be a generic.
if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
{
calleeFunc = as<FunctionDeclBase>(genericDecl->inner);

if (as<ErrorType>(resolved->type.type))
{
// If we can't resolve a type, something went wrong. If we're working with a generic
// decl, the most likely cause is a failure of generic argument inference.
//
visitor->getSink()->diagnose(
derivativeOfAttr,
Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
}
}
}

if (!calleeFunc)
{
visitor->getSink()->diagnose(
Expand Down
57 changes: 57 additions & 0 deletions tests/autodiff/custom-derivative-enum-param.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type

enum MyEnum { A, B, C };

[BackwardDerivative(mDiff)]
float m<let M : MyEnum>(float x)
{
switch (M)
{
case MyEnum.A:
return x * x;
case MyEnum.B:
return x;
case MyEnum.C:
return 3 * x;
default:
return 0;
}
}

void mDiff<let M : MyEnum>(inout DifferentialPair<float> x, float dResult)
{
switch (M)
{
case MyEnum.A:
updateDiff(x, 2 * dResult * x.p);
break;
case MyEnum.B:
updateDiff(x, dResult);
break;
case MyEnum.C:
updateDiff(x, 3 * dResult);
break;
default:
updateDiff(x, 0);
break;
}
}

[Differentiable]
float test(float x)
{
return m<MyEnum.A>(x);
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
var a = diffPair(3.0);
__bwd_diff(test)(a, 1.0);
outputBuffer[dispatchThreadID.x] = a.d;
// CHECK: 6.0
}
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/diagnostics/custom-derivative-generic.slang
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ DifferentialPair<float> dd1(DifferentialPair<float> x)
}

// CHECK-DAG: {{.*}}(37): error 31151
[BackwardDerivative(f)]
[BackwardDerivativeOf(f)]
DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut)
{
var primal = x.p * x.p;
Expand Down

0 comments on commit 9913cfb

Please sign in to comment.