Skip to content

Commit

Permalink
Merge branch 'master' into arial/initializer-list
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhangNV committed Sep 19, 2024
2 parents e95f76f + 3861be7 commit 9333fe1
Show file tree
Hide file tree
Showing 77 changed files with 2,948 additions and 1,355 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ option(SLANG_ENABLE_TESTS "Enable test targets, some tests may require SLANG_ENA
option(SLANG_ENABLE_EXAMPLES "Enable example targets, requires SLANG_ENABLE_GFX" ON)
option(SLANG_ENABLE_REPLAYER "Enable slang-replay tool" ON)

option(SLANG_GITHUB_TOKEN "Use a given token value for accessing Github REST API" "")

enum_option(
SLANG_LIB_TYPE
# Default
Expand Down Expand Up @@ -136,7 +138,7 @@ enum_option(
if(SLANG_SLANG_LLVM_FLAVOR MATCHES FETCH_BINARY)
# If the user didn't specify a URL, find the best one now
if(NOT SLANG_SLANG_LLVM_BINARY_URL)
get_best_slang_binary_release_url(url)
get_best_slang_binary_release_url("${SLANG_GITHUB_TOKEN}" url)
if(NOT DEFINED url)
message(FATAL_ERROR "Unable to find binary release for slang-llvm, please set a different SLANG_SLANG_LLVM_FLAVOR or set SLANG_SLANG_LLVM_BINARY_URL manually")
endif()
Expand Down
28 changes: 23 additions & 5 deletions cmake/GitHubRelease.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function(check_release_and_get_latest owner repo version os arch out_var)
function(check_release_and_get_latest owner repo version os arch github_token out_var)
# Construct the URL for the specified version's release API endpoint
set(version_url "https://api.github.com/repos/${owner}/${repo}/releases/tags/v${version}")

Expand All @@ -17,8 +17,22 @@ function(check_release_and_get_latest owner repo version os arch out_var)
set(${found_var} "${found}" PARENT_SCOPE)
endfunction()

# Download the specified release info from GitHub
file(DOWNLOAD "${version_url}" "${json_output_file}" STATUS download_statuses)
# Prepare download arguments
set(download_args
"${version_url}"
"${json_output_file}"
STATUS download_statuses
)

if(github_token)
# Add authorization header if token is provided
list(APPEND download_args HTTPHEADER "Authorization: token ${github_token}")
endif()

# Perform the download
file(DOWNLOAD ${download_args})

# Check if the downloading was successful
list(GET download_statuses 0 status_code)
if(status_code EQUAL 0)
file(READ "${json_output_file}" json_content)
Expand All @@ -34,6 +48,10 @@ function(check_release_and_get_latest owner repo version os arch out_var)
message(WARNING "Failed to find ${desired_zip} in release assets for ${version} from ${version_url}\nFalling back to latest version if it differs")
else()
message(WARNING "Failed to download release info for version ${version} from ${version_url}\nFalling back to latest version if it differs")

if(status_code EQUAL 22)
message(WARNING "If API rate limit is exceeded, Github allows a higher limit when you use token. Try a cmake option -DSLANG_GITHUB_TOKEN=your_token_here")
endif()
endif()


Expand Down Expand Up @@ -69,7 +87,7 @@ function(check_release_and_get_latest owner repo version os arch out_var)
endif()
endfunction()

function(get_best_slang_binary_release_url out_var)
function(get_best_slang_binary_release_url github_token out_var)
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64|AMD64")
set(arch "x86_64")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|ARM64|arm64")
Expand All @@ -93,7 +111,7 @@ function(get_best_slang_binary_release_url out_var)
set(owner "shader-slang")
set(repo "slang")

check_release_and_get_latest(${owner} ${repo} ${SLANG_VERSION_NUMERIC} ${os} ${arch} release_version)
check_release_and_get_latest(${owner} ${repo} ${SLANG_VERSION_NUMERIC} ${os} ${arch} "${github_token}" release_version)
if(DEFINED release_version)
set(${out_var} "https://github.com/${owner}/${repo}/releases/download/v${release_version}/slang-${release_version}-${os}-${arch}.zip" PARENT_SCOPE)
endif()
Expand Down
3 changes: 2 additions & 1 deletion docs/user-guide/06-interfaces-generics.md
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,8 @@ Slang supports the following builtin interfaces:
- `IInteger`, represents a logical integer that supports both `IArithmetic` and `ILogical` operations. Implemented by all builtin integer scalar types.
- `IDifferentiable`, represents a value that is differentiable.
- `IFloat`, represents a logical float that supports both `IArithmetic`, `ILogical` and `IDifferentiable` operations. Also provides methods to convert to and from `float`. Implemented by all builtin floating-point scalar, vector and matrix types.
- `IArray<T>`, represents a logical array that supports retrieving an element of type `T` from an index. Implemented by array types, vectors and matrices.
- `IArray<T>`, represents a logical array that supports retrieving an element of type `T` from an index. Implemented by array types, vectors, matrices and `StructuredBuffer`.
- `IRWArray<T>`, represents a logical array whose elements are mutable. Implemented by array types, vectors, matrices, `RWStructuredBuffer` and `RasterizerOrderedStructuredBuffer`.
- `IFunc<TResult, TParams...>` represent a callable object (with `operator()`) that returns `TResult` and takes `TParams...` as argument.
- `IMutatingFunc<TResult, TParams...>`, similar to `IFunc`, but the `operator()` method is `[mutating]`.
- `IDifferentiableFunc<TResult, TParams...>`, similar to `IFunc`, but the `operator()` method is `[Differentiable]`.
Expand Down
81 changes: 46 additions & 35 deletions docs/user-guide/07-autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ Where the backward propagation function $$\mathbb{B}[f_i]$$ takes as input the p

The higher order operator $$\mathbb{F}$$ and $$\mathbb{B}$$ represent the operations that converts an original or primal function $$f$$ to its forward or backward derivative propagation function. Slang's automatic differentiation feature provide built-in support for these operators to automatically generate the derivative propagation functions from a user defined primal function. The remaining documentation will discuss this feature from a programming language perspective.

## Differentiable Types
Slang will only generate differentiation code for values that has a *differentiable* type. A type is differentiable if it conforms to the built-in `IDifferentiable` interface. The definition of the `IDifferentiable` interface is:
## Differentiable Value Types
Slang will only generate differentiation code for values that has a *differentiable* type.
Differentiable types are defining through conformance to one of two built-in interfaces:
1. `IDifferentiable`: For value types (e.g. `float`, structs of value types, etc..)
2. `IDifferentiablePtrType`: For buffer, pointer & reference types that represent locations rather than values.

The `IDifferentiable` interface requires the following definitions (which can be auto-generated by the compiler for most scenarios)
```csharp
interface IDifferentiable
{
Expand All @@ -147,23 +152,38 @@ interface IDifferentiable
static Differential dzero();

static Differential dadd(Differential, Differential);

static Differential dmul(This, Differential);
}
```
As defined by the `IDifferentiable` interface, a differentiable type must have a `Differential` associated type that stores the derivative of the value. A further requirement is that the type of the second-order derivative must be the same `Differential` type. In another word, given a type `T`, `T.Differential` can be different from `T`, but `T.Differential.Differential` must equal to `T.Differential`.

In addition, a differentiable type must define the `zero` value of its derivative, and how to add and multiply derivative values.
In addition, a differentiable type must define the `zero` value of its derivative, and how to add two derivative values together. These function are used during reverse-mode auto-diff, to initialize and accumulate derivatives of the given type.

### Builtin Differentiable Types
By contrast, `IDifferentiablePtrType` only requires a `Differential` associated type which also conforms to `IDifferentiablePtrType`.
```csharp
interface IDifferentiablePtrType
{
associatedtype Differential : IDifferentiablePtrType;
where Differential.Differential == Differential;
}
```

Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error.


### Builtin Differentiable Value Types
The following built-in types are differentiable:
- Scalars: `float`, `double` and `half`.
- Vector/Matrix: `vector` and `matrix` of `float`, `double` and `half` types.
- Arrays: `T[n]` is differentiable if `T` is differentiable.
- Tuples: `Tuple<each T>` is differentiable if `T` is differentiable.

### Builtin Differentiable Ptr Types
There are currently no built-in types that conform to `IDifferentiablePtrType`

### User Defined Differentiable Types

The user can make any `struct` types differentiable by implementing the `IDifferentiable` interface on the type. The requirements from `IDifferentiable` interface can be fulfilled automatically or manually.
The user can make any `struct` types differentiable by implementing either `IDifferentiable` & `IDifferentiablePtrType` interface on the type.
The requirements from `IDifferentiable` interface can be fulfilled automatically or manually, though `IDifferentiablePtrType` currently requires the user to provide the `Differential` type.

#### Automatic Fulfillment of `IDifferentiable` Requirements
Assume the user has defined the following type:
Expand Down Expand Up @@ -191,7 +211,7 @@ Note that this code does not provide any explicit implementation of the `IDiffer
1. A new type is generated that stores the `Differential` of all differentiable fields. This new type itself will conform to the `IDifferentiable` interface, and it will be used to satisfy the `Differential` associated type requirement.
2. Each differential field will be associated to its corresponding field in the newly synthesized `Differential` type.
3. The `zero` value of the differential type is made from the `zero` value of each field in the differential type.
4. The `dadd` and `dmul` methods simply perform `dadd` and `dmul` operations on each field.
4. The `dadd` method invokes the `dadd` operations for each field whose type conforms to `IDifferentiable`.
5. If the synthesized `Differential` type contains exactly the same fields as the original type, and the type of each field is the same as the original field type, then the original type itself will be used as the `Differential` type instead of creating a new type to satisfy the `Differential` associated type requirement. This means that all the synthesized `Differential` type use itself to meet its own `IDifferentiable` requirements.

#### Manual Fulfillment of `IDifferentiable` Requirements
Expand Down Expand Up @@ -235,15 +255,6 @@ struct MyRay : IDifferentiable
result.d_dir = v1.d_dir + v2.d_dir;
return result;
}

// Define the multiply operation of a primal value and a derivative value.
static MyRayDifferential dmul(MyRay p, MyRayDifferential d)
{
MyRayDifferential result;
result.d_origin = p.origin * d.d_origin;
result.d_dir = p.dir * d.d_dir;
return result;
}
}
```

Expand Down Expand Up @@ -284,14 +295,6 @@ struct MyRayDifferential : IDifferentiable
result.d_dir = v1.d_dir + v2.d_dir;
return result;
}

static MyRayDifferential dmul(MyRayDifferential p, MyRayDifferential d)
{
MyRayDifferential result;
result.d_origin = p.d_origin * d.d_origin;
result.d_dir = p.d_dir * d.d_dir;
return result;
}
}
```
In this specific case, the automatically generated `IDifferentiable` implementation will be exactly the same as the manually written code listed above.
Expand All @@ -303,17 +306,19 @@ Functions in Slang can be marked as forward-differentiable or backward-different

A forward derivative propagation function computes the derivative of the result value with regard to a specific set of input parameters.
Given an original function, the signature of its forward propagation function is determined using the following rules:
- If the return type `R` is differentiable, the forward propagation function will return `DifferentialPair<R>` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`.
- If a parameter has type `T` that is differentiable, it will be translated into a `DifferentialPair<T>` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters.
- If the return type `R` implements `IDifferentiable` the forward propagation function will return a corresponding `DifferentialPair<R>` that consists of both the computed original result value and the (partial) derivative of the result value. Otherwise, the return type is kept unmodified as `R`.
- If a parameter has type `T` that implements `IDifferentiable`, it will be translated into a `DifferentialPair<T>` parameter in the derivative function, where the differential component of the `DifferentialPair` holds the initial derivatives of each parameter with regard to their upstream parameters.
- If a parameter has type `T` that implements `IDifferentiablePtrType`, it will be translated into a `DifferentialPtrPair<T>` parameter where the differential component references the differential location or buffer.
- All parameter directions are unchanged. For example, an `out` parameter in the original function will remain an `out` parameter in the derivative function.
- Differentiable methods cannot have a type implementing `IDifferentiablePtrType` as an `out` or `inout` parameter, or a return type. Types implementing `IDifferentiablePtrType` can only be used for input parameters to a differentiable method. Marking such a method as `[Differentiable]` will result in a compile-time diagnostic error.

For example, given original function:
```csharp
R original(T0 p0, inout T1 p1, T2 p2);
R original(T0 p0, inout T1 p1, T2 p2, T3 p3);
```
Where `R`, `T0`, and `T1` is differentiable and `T2` is non-differentiable, the forward derivative function will have the following signature:
Where `R`, `T0`, `T1 : IDifferentiable`, `T2` is non-differentiable, and `T3 : IDifferentiablePtrType`, the forward derivative function will have the following signature:
```csharp
DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2);
DifferentialPair<R> derivative(DifferentialPair<T0> p0, inout DifferentialPair<T1> p1, T2 p2, DifferentialPtrPair<T3> p3);
```

This forward propagation function takes the initial primal value of `p0` in `p0.p`, and the partial derivative of `p0` with regard to some upstream parameter in `p0.d`. It takes the initial primal and derivative values of `p1` and updates `p1` to hold the newly computed value and propagated derivative. Since `p2` is not differentiable, it remains unchanged.
Expand All @@ -327,10 +332,11 @@ struct DifferentialPair<T : IDifferentiable> : IDifferentiable
property T.Differential d {get;}
static Differential dzero();
static Differential dadd(Differential a, Differential b);
static Differential dmul(This a, Differential b);
}
```

For ptr-types, there is a corresponding built-in `DifferentialPtrPair<T : IDifferentiablePtrPair>` that does not have the `dzero` or `dadd` methods.

### Automatic Implementation of Forward Derivative Functions

A function can be made forward-differentiable with a `[ForwardDifferentiable]` attribute. This attribute will cause the compiler to automatically implement the forward propagation function. The syntax for using `[ForwardDifferentiable]` is:
Expand Down Expand Up @@ -392,23 +398,25 @@ Given an original function `f`, the general rule for determining the signature o

More specifically, the signature of its backward propagation function is determined using the following rules:
- A backward propagation function always returns `void`.
- A differentiable `in` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input.
- A differentiable `out` parameter of type `T` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value.
- A differentiable `inout` parameter of type `T` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated.
- A differentiable `in` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair<T>` parameter, where the original value part of the differential pair contains the original value of the parameter to pass into the back-prop function. The original value will not be overwritten by the backward propagation function. The propagated derivative will be written to the derivative part of the differential pair after the backward propagation function returns. The initial derivative value of the pair is ignored as input.
- A differentiable `out` parameter of type `T : IDifferentiable` will become an `in T.Differential` parameter, carrying the partial derivative of some downstream term with regard to the return value.
- A differentiable `inout` parameter of type `T : IDifferentiable` will become an `inout DifferentialPair<T>` parameter, where the original value of the argument, along with the downstream partial derivative with regard to the argument is passed as input to the backward propagation function as the original and derivative part of the pair. The propagated derivative with regard to this input parameter will be written back and replace the derivative part of the pair. The primal value part of the parameter will *not* be updated.
- A differentiable return value of type `R` will become an additional `in R.Differential` parameter at the end of the backward propagation function parameter list, carrying the result derivative of a downstream term with regard to the return value of the original function.
- A non-differentiable return value of type `NDR` will be dropped.
- A non-differentiable `in` parameter of type `ND` will remain unchanged in the backward propagation function.
- A non-differentiable `out` parameter of type `ND` will be removed from the parameter list of the backward propagation function.
- A non-differentiable `inout` parameter of type `ND` will become an `in ND` parameter.
- Types implemented `IDifferentiablePtrType` work the same was as the forward-mode case. They can only be used with `in` parameters, and are converted into `DifferentialPtrPair` types. Their directions are not affected.

For example consider the following original function:
```csharp
struct T : IDifferentiable {...}
struct R : IDifferentiable {...}
struct P : IDifferentiablePtrType {...}
struct ND {} // Non differentiable
[Differentiable]
R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5);
R original(T p0, out T p1, inout T p2, ND p3, out ND p4, inout ND p5, P p6);
```
The signature of its backward propagation function is:
```csharp
Expand All @@ -418,6 +426,7 @@ void back_prop(
inout DifferentialPair<T> p2,
ND p3,
ND p5,
DifferentialPtrPair<P> p6,
R.Differential dResult);
```
Note that although `p2` is still `inout` in the backward propagation function, the backward propagation function will only write propagated derivative to `p2.d` and will not modify `p2.p`.
Expand Down Expand Up @@ -447,6 +456,7 @@ void back_prop(
inout DifferentialPair<T> p2,
ND p3,
ND p5,
DifferentialPtrPair<P> p6,
R.Differential dResult)
{
...
Expand All @@ -468,6 +478,7 @@ void back_prop(
inout DifferentialPair<T> p2,
ND p3,
ND p5,
DifferentialPtrPair<P> p6,
R.Differential dResult)
{
...
Expand Down
Loading

0 comments on commit 9333fe1

Please sign in to comment.