-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
374 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
Guides | ||
============= | ||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
guides/introduction.rst | ||
guides/promotion.rst | ||
guides/prelude.rst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
def name(x):
    """Return the short label used in the promotion table for a type.

    ``x`` is a ``(kind, bits)`` pair such as ``("i", 32)`` or ``("bf", 16)``.
    The boolean type ``("b", 1)`` is rendered as plain ``"b"``; every other
    type is rendered as the kind immediately followed by its bit width.
    """
    return "b" if x == ("b", 1) else "{}{}".format(x[0], x[1])
|
||
def promote(a, b):
    """Compute the common type that ``a`` and ``b`` promote to.

    Both arguments are ``(kind, bits)`` pairs, where kind is one of ``"b"``
    (bool), ``"i"`` (signed), ``"u"`` (unsigned), ``"f"`` (float) or ``"bf"``
    (bfloat). Returns the promoted ``(kind, bits)`` pair, or ``None`` when the
    two types cannot be combined (signed mixed with unsigned).
    """
    kind_a, kind_b = a[0], b[0]
    float_kinds = ("f", "bf")
    int_kinds = ("i", "u")

    # Same kind: the wider of the two wins.
    if kind_a == kind_b:
        return (kind_a, max(a[1], b[1]))

    # A boolean always yields to the other operand.
    if kind_a == "b":
        return b
    if kind_b == "b":
        return a

    a_is_float = kind_a in float_kinds
    b_is_float = kind_b in float_kinds

    # Floating point combined with an integer stays floating point.
    if a_is_float and kind_b in int_kinds:
        return a
    if b_is_float and kind_a in int_kinds:
        return b

    # Two different floating-point formats: the wider one wins; formats of
    # equal width are widened to a regular float of twice the size.
    if a_is_float and b_is_float:
        if a[1] > b[1]:
            return a
        if b[1] > a[1]:
            return b
        return ("f", a[1] * 2)

    # Anything left over (signed vs. unsigned) has no common type.
    return None
|
||
|
||
if __name__ == "__main__":
    # Build the ordered list of types that make up both table axes.
    bit_widths = [8, 16, 32, 64]
    types = [("b", 1)]
    for kind in ("i", "u", "f"):
        types.extend((kind, n) for n in bit_widths)

    # bfloat16 sits directly in front of f32 in the table ordering.
    types.insert(types.index(("f", 32)), ("bf", 16))

    # Bold labels are shared by the header row and the first column.
    labels = [f"**{name(t)}**" for t in types]

    # Header row: an empty corner cell followed by one label per type.
    rows = [",".join([""] + labels)]

    # One row per left-hand type; "x" marks combinations without a result.
    for lhs, label in zip(types, labels):
        cells = [label]
        for rhs in types:
            result = promote(lhs, rhs)
            cells.append(name(result) if result else "x")
        rows.append(",".join(cells))

    with open("promotion_table.csv", "w") as f:
        f.write("\n".join(rows))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
Getting started | ||
=============== | ||
|
||
Kernel Float is a header-only library that makes it easy to work with vector types and low-precision floating-point types, mainly focusing on CUDA kernel code. | ||
|
||
Installation | ||
------------ | ||
|
||
The easiest way to use the library is to get the single header file from GitHub: | ||
|
||
```bash | ||
wget https://raw.githubusercontent.com/KernelTuner/kernel_float/main/single_include/kernel_float.h | ||
``` | ||
|
||
Next, include this file into your program. | ||
It is convenient to define a namespace alias `kf` to shorten the full name `kernel_float`. | ||
|
||
|
||
```C++ | ||
#include "kernel_float.h" | ||
namespace kf = kernel_float; | ||
``` | ||
Example C++ code | ||
---------------- | ||
Kernel Float essentially offers a single data-type `kernel_float::vec<T, N>` that stores `N` elements of type `T`. | ||
This type can be initialized normally using list-initialization (e.g., `{a, b, c}`) and elements can be accessed using the `[]` operator. | ||
Operator overloading is available to perform binary operations (such as `+`, `*`, and `&`), where the optimal intrinsic for the available types is selected automatically. | ||
Many mathematical functions (like `log`, `sin`, `cos`) are also available, see the [API reference](../api) for the full list of functions. | ||
In some cases, certain operations might not be natively supported by the platform for some floating-point types. | ||
In these cases, Kernel Float falls back to performing the operations in 32 bit precision. | ||
The code below shows a very simple example of how to use Kernel Float: | ||
```C++ | ||
#include "kernel_float.h" | ||
namespace kf = kernel_float; | ||
int main() { | ||
using Type = float; | ||
const int N = 8; | ||
kf::vec<int, N> i = kf::range<int, N>(); | ||
kf::vec<Type, N> x = kf::cast<Type>(i); | ||
kf::vec<Type, N> y = x * kf::sin(x); | ||
Type result = kf::sum(y); | ||
printf("result=%f", double(result)); | ||
return EXIT_SUCCESS; | ||
} | ||
``` | ||
|
||
Notice how easy it would be to change the floating-point type `Type` or the vector length `N` without affecting the rest of the code. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
Using `kernel_float::prelude` | ||
=== | ||
|
||
When working with Kernel Float, you'll find that you need to prefix every function and type with the `kernel_float::...` prefix. | ||
This can be a bit cumbersome. | ||
It's strongly discouraged to dump the entire `kernel_float` namespace into the global namespace (with `using namespace kernel_float`) since | ||
many symbols in Kernel Float may clash with global symbols, causing conflicts and issues. | ||
|
||
To work around this, the library provides a handy `kernel_float::prelude` namespace. This namespace contains a variety of useful type and function aliases that won't conflict with global symbols. | ||
|
||
To make use of it, use the following code: | ||
|
||
|
||
```C++ | ||
#include "kernel_float.h" | ||
using namespace kernel_float::prelude; | ||
|
||
// You can now use aliases like `kf`, `kvec`, `kint`, etc. | ||
``` | ||
The prelude defines many aliases, including the following: | ||
| Prelude name | Full name | | ||
|---|---| | ||
| `kf` | `kernel_float` | | ||
| `kvec<T, N>` | `kernel_float::vec<T, N>` | | ||
| `into_kvec(v)` | `kernel_float::into_vec(v)` | | ||
| `make_kvec(a, b, ...)` | `kernel_float::make_vec(a, b, ...)` | | ||
| `kvec2<T>`, `kvec3<T>`, ... | `kernel_float::vec<T, 2>`, `kernel_float::vec<T, 3>`, ... | | ||
| `kint<N>` | `kernel_float::vec<int, N>` | | ||
| `kint2`, `kint3`, ... | `kernel_float::vec<int, 2>`, `kernel_float::vec<int, 3>`, ... | | ||
| `klong<N>` | `kernel_float::vec<long, N>` | | ||
| `klong2`, `klong3`, ... | `kernel_float::vec<long, 2>`, `kernel_float::vec<long, 3>`, ... | | ||
| `kbfloat16x<N>` | `kernel_float::vec<bfloat16, N>` | | ||
| `kbfloat16x2`, `kbfloat16x3`, ... | `kernel_float::vec<bfloat16, 2>`, `kernel_float::vec<bfloat16, 3>`, ... | | ||
| `khalf<N>` | `kernel_float::vec<half, N>` | | ||
| `khalf2`, `khalf3`, ... | `kernel_float::vec<half, 2>`, `kernel_float::vec<half, 3>`, ... | | ||
| `kfloat<N>` | `kernel_float::vec<float, N>` | | ||
| `kfloat2`, `kfloat3`, ... | `kernel_float::vec<float, 2>`, `kernel_float::vec<float, 3>`, ... | | ||
| `kdouble<N>` | `kernel_float::vec<double, N>` | | ||
| `kdouble2`, `kdouble3`, ... | `kernel_float::vec<double, 2>`, `kernel_float::vec<double, 3>`, ... | | ||
| ... | ... | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
Type Promotion | ||
============== | ||
|
||
For operations that involve two input arguments (or more), ``kernel_float`` will first convert the inputs into a common type before applying the operation. | ||
For example, when adding ``vec<int, N>`` to a ``vec<float, N>``, both arguments must first be converted into a ``vec<float, N>``. | ||
|
||
This procedure is called "type promotion" and is implemented as follows. | ||
First, all arguments are converted into a vector by calling ``into_vec``. | ||
Next, all arguments must have length ``N`` or length ``1`` and vectors of length ``1`` are resized to become length ``N``. | ||
Finally, the vector element types are promoted into a common type. | ||
|
||
The rules for element type promotion in ``kernel_float`` are slightly different than in regular C++. | ||
In short, for two element types ``T`` and ``U``, the promotion rules can be summarized as follows: | ||
|
||
* If one of the types is ``bool``, the result is the other type. | ||
* If one type is a floating-point type and the other is a signed or unsigned integer, the result is the floating-point type. | ||
* If both types are floating-point types, the result is the largest of the two types. An exception here is combining ``half`` and ``bfloat16``, which results in ``float``. | ||
* If both types are integer types of the same signedness, the result is the largest of the two types. | ||
* Combining a signed integer and unsigned integer type is not allowed. | ||
|
||
Overview | ||
-------- | ||
|
||
The type promotion rules are shown in the table below. | ||
The labels are as follows: | ||
|
||
* ``b``: boolean | ||
* ``iN``: signed integer of ``N`` bits (e.g., ``int``, ``long``) | ||
* ``uN``: unsigned integer of ``N`` bits (e.g., ``unsigned int``, ``size_t``) | ||
* ``fN``: floating-point type of ``N`` bits (e.g., ``float``, ``double``) | ||
* ``bf16``: bfloat16 floating-point format. | ||
|
||
.. csv-table:: Type Promotion Rules. | ||
:file: promotion_table.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
,**b**,**i8**,**i16**,**i32**,**i64**,**u8**,**u16**,**u32**,**u64**,**f8**,**f16**,**bf16**,**f32**,**f64** | ||
**b**,b,i8,i16,i32,i64,u8,u16,u32,u64,f8,f16,bf16,f32,f64 | ||
**i8**,i8,i8,i16,i32,i64,x,x,x,x,f8,f16,bf16,f32,f64 | ||
**i16**,i16,i16,i16,i32,i64,x,x,x,x,f8,f16,bf16,f32,f64 | ||
**i32**,i32,i32,i32,i32,i64,x,x,x,x,f8,f16,bf16,f32,f64 | ||
**i64**,i64,i64,i64,i64,i64,x,x,x,x,f8,f16,bf16,f32,f64 | ||
**u8**,u8,x,x,x,x,u8,u16,u32,u64,f8,f16,bf16,f32,f64 | ||
**u16**,u16,x,x,x,x,u16,u16,u32,u64,f8,f16,bf16,f32,f64 | ||
**u32**,u32,x,x,x,x,u32,u32,u32,u64,f8,f16,bf16,f32,f64 | ||
**u64**,u64,x,x,x,x,u64,u64,u64,u64,f8,f16,bf16,f32,f64 | ||
**f8**,f8,f8,f8,f8,f8,f8,f8,f8,f8,f8,f16,bf16,f32,f64 | ||
**f16**,f16,f16,f16,f16,f16,f16,f16,f16,f16,f16,f16,f32,f32,f64 | ||
**bf16**,bf16,bf16,bf16,bf16,bf16,bf16,bf16,bf16,bf16,bf16,f32,bf16,f32,f64 | ||
**f32**,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64 | ||
**f64**,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Promotion declarations for the CUDA 8-bit floating-point types | ||
// `__nv_fp8_e4m3` and `__nv_fp8_e5m2`. Apart from the include guard and the | ||
// "macros.h" include, everything is compiled only when | ||
// KERNEL_FLOAT_FP8_AVAILABLE evaluates to true. | ||
// | ||
// NOTE(review): KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT and | ||
// KERNEL_FLOAT_DEFINE_PROMOTED_TYPE are defined outside this file. Judging by | ||
// the documented promotion table (f8 combined with a wider float yields the | ||
// wider float), KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(T, U) presumably declares | ||
// that combining T with U promotes to T -- confirm against the macro definition. | ||
#ifndef KERNEL_FLOAT_FP8_H | ||
#define KERNEL_FLOAT_FP8_H | ||
|
||
#include "macros.h" | ||
|
||
#if KERNEL_FLOAT_FP8_AVAILABLE | ||
#include <cuda_fp8.h> | ||
|
||
#include "vector.h" | ||
|
||
namespace kernel_float { | ||
// Declarations pairing the e4m3 format with `float` and `double`. | ||
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e4m3) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e4m3) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e4m3) | ||
|
||
// The same set of declarations for the e5m2 format. | ||
KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_fp8_e5m2) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_fp8_e5m2) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_fp8_e5m2) | ||
} // namespace kernel_float | ||
|
||
// Pairings with `__half`; only compiled when KERNEL_FLOAT_FP16_AVAILABLE is set. | ||
#if KERNEL_FLOAT_FP16_AVAILABLE | ||
#include "fp16.h" | ||
|
||
namespace kernel_float { | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2) | ||
} // namespace kernel_float | ||
#endif // KERNEL_FLOAT_FP16_AVAILABLE | ||
|
||
// Pairings with `__nv_bfloat16`; only compiled when KERNEL_FLOAT_BF16_AVAILABLE is set. | ||
#if KERNEL_FLOAT_BF16_AVAILABLE | ||
#include "bf16.h" | ||
|
||
namespace kernel_float { | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3) | ||
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2) | ||
} // namespace kernel_float | ||
#endif // KERNEL_FLOAT_BF16_AVAILABLE | ||
|
||
#endif // KERNEL_FLOAT_FP8_AVAILABLE | ||
#endif // KERNEL_FLOAT_FP8_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.