Differentiable Programming with Slang #2462
Closed
saipraveenb25 started this conversation in Ideas
Differentiable Programming with Slang
This page is living documentation of the auto-diff features currently available in Slang and how to use them. We will update it as features are added or removed.
Forward-mode auto-diff with __fwd_diff(fn)
You can invoke the forward-mode derivative of an arbitrary function by wrapping it with __fwd_diff(). As an example, let's take a simple 1-input, 1-output function that computes the square:
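A minimal sketch of such a function. The name sqr matches the text that follows; the body is an assumption about what the example looks like.

```slang
// Plain function, not yet marked as differentiable: y = x * x.
float sqr(float x)
{
    return x * x;
}
```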
The forward-mode derivative computes the differential of the output, given the differentials of its inputs.
To obtain the forward-mode derivative of sqr, we first mark sqr as differentiable by adding a [ForwardDifferentiable] attribute to the signature, and then invoke __fwd_diff(sqr).
The new function requires differentials in addition to the original inputs. In Slang, we enable this by transforming each input into a pair of itself and its differential.
We achieve this through the generic struct DifferentialPair<T>, which holds both the original value and the differential. The former can be accessed using .p() or .getPrimal(), and the latter through .d() or .getDifferential().
The synthesized function looks like the following:
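As a sketch, the function produced by __fwd_diff(sqr) behaves roughly like the hand-written version below. The name sqr_fwd and the two-argument DifferentialPair constructor are assumptions made for illustration; the accessors are the ones named above.

```slang
// Conceptual equivalent of __fwd_diff(sqr).
// sqr_fwd is a hypothetical name; the (primal, differential) constructor is assumed.
DifferentialPair<float> sqr_fwd(DifferentialPair<float> dpx)
{
    float x  = dpx.p();  // original (primal) input
    float dx = dpx.d();  // differential of the input

    // Product rule: d(x * x) = 2 * x * dx
    return DifferentialPair<float>(x * x, 2.0 * x * dx);
}
```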
Here's a complete example:
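A minimal end-to-end sketch under the same assumptions (the DifferentialPair constructor form and the function name example are illustrative, not confirmed API):

```slang
[ForwardDifferentiable]
float sqr(float x)
{
    return x * x;
}

void example()
{
    // Evaluate at x = 3 with an input differential of 1
    // (constructor form is an assumption).
    DifferentialPair<float> dpx = DifferentialPair<float>(3.0, 1.0);

    // Call the compiler-synthesized forward-mode derivative.
    DifferentialPair<float> dpy = __fwd_diff(sqr)(dpx);

    float y  = dpy.p();  // primal result: 3 * 3 = 9
    float dy = dpy.d();  // differential:  2 * 3 * 1 = 6
}
```

Passing a differential of 1 for the input yields the ordinary derivative dy/dx in the differential slot of the result.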
Basic Differentiable Types
Currently, the following types will be detected as being automatically differentiable:
float
float2, float3, and float4 (but not general N-sized vectors)
More type support is being actively added.
Custom Differentiable Types
Define a type as differentiable by extending the IDifferentiable interface. In most cases, it is sufficient to declare the conformance on the type, and the compiler will automatically synthesize all the code required to implement the IDifferentiable interface.
The definition of the IDifferentiable interface has four requirements:
The Differential associatedtype specifies the type that represents the result of differentiating a value of this type.
The dzero method defines the zero value of the differential type.
The dadd method defines the add operation on two differential values.
The dmul method defines the inner product of a primal value and a differential value.
In order to use a member of a differentiable struct type, that member must be marked with a [DerivativeMember] attribute. Members without the [DerivativeMember] attribute are treated as non-differentiable. However, if the user does not specify the Differential associatedtype, the compiler will automatically synthesize one that includes all the differentiable members and insert the [DerivativeMember] attributes automatically.
With automatic IDifferentiable conformance synthesis, the user writes only the conformance and the compiler synthesizes the necessary definitions within the struct. The result is equivalent to the user writing out the Differential type, the [DerivativeMember] attributes, and the dzero, dadd, and dmul methods by hand.
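The two forms can be sketched as follows. The struct names, the exact method signatures, and the argument form of [DerivativeMember] are assumptions made for illustration; only the requirement names (Differential, dzero, dadd, dmul) come from the text above.

```slang
// Short form: declare the conformance and let the compiler synthesize the rest.
struct MyVec : IDifferentiable
{
    float x;
    float y;
};

// Hand-written form that the short form is roughly equivalent to.
struct MyVecExplicit : IDifferentiable
{
    // Assumption: the differential has the same layout as the primal type.
    typealias Differential = MyVecExplicit;

    // Assumption: the attribute names the matching member of Differential.
    [DerivativeMember(Differential.x)]
    float x;

    [DerivativeMember(Differential.y)]
    float y;

    // Zero value of the differential type.
    static Differential dzero()
    {
        Differential d;
        d.x = 0.0;
        d.y = 0.0;
        return d;
    }

    // Member-wise sum of two differentials.
    static Differential dadd(Differential a, Differential b)
    {
        Differential d;
        d.x = a.x + b.x;
        d.y = a.y + b.y;
        return d;
    }

    // Member-wise product of a primal value and a differential.
    static Differential dmul(MyVecExplicit p, Differential b)
    {
        Differential d;
        d.x = p.x * b.x;
        d.y = p.y * b.y;
        return d;
    }
};
```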
Custom Forward-mode functions
Sometimes it is desirable to provide a custom hand-written derivative for a function rather than have the compiler synthesize one.
The attribute [ForwardDerivative(<derivative-fn-name>)] can be used to achieve this. A complete example of this:
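A sketch of a function with a hand-written derivative. The names sqr_fwd, cube, cube_fwd, and example, and the two-argument DifferentialPair constructor, are assumptions. The second pair shows the same kind of association made from the derivative's side with [ForwardDerivativeOf], which is described in the next paragraph.

```slang
// Hand-written derivative of sqr (hypothetical name sqr_fwd).
DifferentialPair<float> sqr_fwd(DifferentialPair<float> dpx)
{
    float x  = dpx.p();
    float dx = dpx.d();
    return DifferentialPair<float>(x * x, 2.0 * x * dx);
}

// Use sqr_fwd instead of synthesizing a derivative for sqr.
[ForwardDerivative(sqr_fwd)]
float sqr(float x)
{
    return x * x;
}

// Alternative: leave the original function undecorated...
float cube(float x)
{
    return x * x * x;
}

// ...and attach the derivative to it from the derivative's side.
[ForwardDerivativeOf(cube)]
DifferentialPair<float> cube_fwd(DifferentialPair<float> dpx)
{
    float x  = dpx.p();
    float dx = dpx.d();
    return DifferentialPair<float>(x * x * x, 3.0 * x * x * dx);
}

void example()
{
    DifferentialPair<float> dpx = DifferentialPair<float>(2.0, 1.0);

    // Both calls are expected to resolve to the hand-written derivatives.
    DifferentialPair<float> a = __fwd_diff(sqr)(dpx);   // differential: 2 * 2 * 1 = 4
    DifferentialPair<float> b = __fwd_diff(cube)(dpx);  // differential: 3 * 2 * 2 * 1 = 12
}
```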
Custom derivative functions can also be provided with the [ForwardDerivativeOf] attribute, which works the opposite way from [ForwardDerivative]. Instead of decorating the original function, the [ForwardDerivativeOf] attribute decorates the derivative function and associates it with an original function. The previous example can equally be written by placing [ForwardDerivativeOf] on the derivative function instead.
Calling Non-Differentiable Functions from a Differentiable Function
Calling non-differentiable functions from a differentiable function is allowed, as long as the result of the non-differentiable function call does not directly contribute to any differentiable result (control-flow dependence doesn't count as contributing).
It is a compile-time error if the result of a non-differentiable function call is used to produce a differentiable result value.
For example:
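A sketch of the situation, together with the no_diff form that makes the call intentional. The function names are hypothetical, and the exact spelling of the no_diff expression (shown here wrapping the call in parentheses) is an assumption.

```slang
// Not marked differentiable.
float g(float x)
{
    return x > 0.0 ? 1.0 : 0.0;
}

[ForwardDifferentiable]
float f(float x)
{
    // Error if uncommented: the result of the non-differentiable call g(x)
    // contributes directly to the differentiable return value.
    // return g(x) + x * x;

    // Accepted: no_diff states that g(x) is deliberately treated as
    // non-differentiable, so only x * x contributes to the differential.
    return no_diff(g(x)) + x * x;
}

[ForwardDifferentiable]
float h(float x)
{
    // Also accepted without no_diff: g(x) only steers control flow,
    // which does not count as contributing to the result.
    if (g(x) > 0.5)
        return x * x;
    return x;
}
```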
However, this compile error can be suppressed with the no_diff keyword to clarify that the non-differentiable call is intentional.
Standard library functions
Most floating-point math functions have builtin derivative implementations, so calls to these functions are differentiated automatically. The following code is an example of calling a stdlib function in a differentiable function.
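A minimal sketch, assuming sin and exp are among the functions with builtin derivatives. The function names and the DifferentialPair constructor form are illustrative.

```slang
[ForwardDifferentiable]
float f(float x)
{
    // Both calls have builtin derivatives, so no extra annotation is needed.
    return sin(x) * exp(x);
}

void example()
{
    // d/dx [sin(x) * exp(x)] = (cos(x) + sin(x)) * exp(x), which is 1 at x = 0.
    DifferentialPair<float> r = __fwd_diff(f)(DifferentialPair<float>(0.0, 1.0));
    float y  = r.p();  // sin(0) * exp(0) = 0
    float dy = r.d();  // (cos(0) + sin(0)) * exp(0) = 1
}
```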
The following stdlib functions have derivative implementations: log, sin, cos, exp, pow, abs, sqrt, max, min, dot.
You can add more stdlib support by adding definitions using the [ForwardDerivativeOf] attribute. For example, a function __d_sin carrying this attribute defines the derivative implementation for sin. All locations that see the definition of __d_sin will treat sin as a differentiable function.
Dos and Don'ts in Differentiable Functions
Not all statements are supported inside a differentiable function. Here's an incomplete list of what is and isn't supported.
Okay to use inside differentiable functions:
Calls to functions marked with the [ForwardDifferentiable] or [ForwardDerivative] attribute.
Struct members marked with the [DerivativeMember] attribute, either explicitly by the user or implicitly by the compiler through automatic IDifferentiable conformance synthesis.
for loops as well as if-else conditional blocks.
The no_diff keyword, to clarify the intention when the result of a non-differentiable function contributes to a differentiable result value.
Not okay to use inside differentiable functions (error or undefined behaviour):
TraceRay or TraceRayInline
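As a rough illustration of the supported constructs listed above (a call to another differentiable function, a for loop, and an if-else block), with hypothetical function names:

```slang
[ForwardDifferentiable]
float sqr(float x)
{
    return x * x;
}

[ForwardDifferentiable]
float poly(float x)
{
    float result = 0.0;

    // for loop with a fixed trip count
    for (int i = 0; i < 3; i++)
    {
        // call to another differentiable function
        result = result + sqr(x);
    }

    // if-else conditional block
    if (x < 0.0)
    {
        result = -result;
    }
    else
    {
        result = result + x;
    }

    return result;
}
```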