[Operator] Type promotion for pointwise Ops #79

Merged (7 commits, Jul 1, 2024)

Changes from all commits

1 change: 1 addition & 0 deletions .github/workflows/python-test.yaml
@@ -26,6 +26,7 @@ jobs:
- name: unit_test-flag-gems
run: |
CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py &
CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py &
CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &
CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &
CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &
31 changes: 28 additions & 3 deletions README.md
@@ -18,7 +18,7 @@ In FlagGems, we provide automatic code generation that developers can use to con
Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code.

```python
@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")])
@triton.jit
def abs_func(x):
return tl.abs(x)
@@ -29,15 +29,19 @@ def abs_func(x):
By default, `pointwise_dynamic` treats all parameters as tensors, and by passing a list of boolean values to the parameter `is_tensor`, developers can specify which parameters are tensors and which are not. Additionally, developers can pass in `dtypes` to indicate the data types of non-tensor parameters, but this is not required. For example, in the following code, the `alpha` parameter is defined as a non-tensor floating point number, while the `x` and `y` parameters are defined as tensors.

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@pointwise_dynamic(
is_tensor=[True, True, False],
dtypes=[None, None, float],
promotion_methods=[(0,"DEFAULT")]
)
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### Output Data Type

By default, all output tensors have the same data type as the first input tensor, but it can also be customized by providing a list of data types to the parameter `output_dtypes`. For example, in the following code, the output tensor type is specified as `torch.bool`.
Furthermore, developers must provide `promotion_methods` to specify how type promotion should be handled for the operation, so that the computation produces outputs of the correct data type.

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
@@ -46,6 +50,27 @@ def ge(x, y):
return x > y
```

In `promotion_methods`, each tuple lists the positions (`int`) of the parameters that participate in type promotion, followed by a `str` naming the promotion rule. The `str` corresponds to one of the following enumeration values:

```python
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
DEFAULT = (0,)
NO_OPMATH = (1,)
INT_TO_FLOAT = (2,)
ALWAYS_BOOL = (3,)
COMPLEX_TO_FLOAT = (4,)
BOOL_TO_LONG = (5,)
```

Examples:

- `DEFAULT`: add
- `NO_OPMATH`: where, nextafter, cat
- `INT_TO_FLOAT`: sin
- `ALWAYS_BOOL`: eq
- `COMPLEX_TO_FLOAT`: abs
- `BOOL_TO_LONG`: pow
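
To make the tuple format concrete, the following sketch shows how a binary and a unary operator might be declared. It is illustrative rather than part of the library: the import path `flag_gems.utils` and the wrapper names `mul_func` and `sin_func` are assumptions, and the kernels simply follow the same pattern as the operators changed in this PR.

```python
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic  # assumed import path


# Positions 0 and 1 both take part in promotion; "DEFAULT" applies the usual
# PyTorch elementwise rules (e.g. int32 combined with float16 gives float16).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    return x * y


# "INT_TO_FLOAT" promotes integer inputs to a floating-point result, the
# behavior expected of transcendental ops such as sin.
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sin_func(x):
    return tl.sin(x.to(tl.float32))
```

The decorated functions are then called like ordinary Python functions from their ATen wrappers, as the operator files in this PR do.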

## Changelog

### v1.0
33 changes: 29 additions & 4 deletions README_cn.md
@@ -18,7 +18,7 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库
在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。

```python
@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")])
@triton.jit
def abs_func(x):
return tl.abs(x)
@@ -29,23 +29,48 @@ def abs_func(x):
在默认情况下,`pointwise_dynamic`将所有参数均处理为张量,而通过向参数`is_tensor`传递布尔值列表,开发者可以指定哪些参数是张量,哪些参数非张量。此外,开发者还可以传入`dtypes`说明非张量参数的数据类型,但这不是必要的。例如以下代码,将`alpha`参数定义为非张量的浮点数,而`x`和`y`参数定义为张量。

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@pointwise_dynamic(
is_tensor=[True, True, False],
dtypes=[None, None, float],
promotion_methods=[(0,"DEFAULT")]
)
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### 输出数据类型

在默认情况下,输出张量使用与首个输入张量相同的数据类型,但也可向参数`output_dtypes`传入数据类型组成的列表来指定。例如以下代码,指定输出张量类型为`torch.bool`。
此外,开发者必须传入 `promotion_methods` 来说明该 Op 在进行计算时应该如何进行类型提升,以获得正确的输出类型。

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def ge(x, y):
return x > y
```

`promotion_methods` 通过传入 `int` 来表示需要进行类型提升的参数位置,通过传入 `str` 来表示类型提升的方式。`str` 对应以下枚举类型:

```python
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
DEFAULT = (0,)
NO_OPMATH = (1,)
INT_TO_FLOAT = (2,)
ALWAYS_BOOL = (3,)
COMPLEX_TO_FLOAT = (4,)
BOOL_TO_LONG = (5,)
```

举例:

- `DEFAULT`: add
- `NO_OPMATH`: where, nextafter, cat
- `INT_TO_FLOAT`: sin
- `ALWAYS_BOOL`: eq
- `COMPLEX_TO_FLOAT`: abs
- `BOOL_TO_LONG`: pow

## 更新日志

### v1.0
4 changes: 2 additions & 2 deletions src/flag_gems/fused/gelu_and_mul.py
@@ -7,15 +7,15 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_none_and_mul_kernel(x, y):
x_fp32 = x.to(tl.float32)
x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811))
return x_gelu * y


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
x_fp32 = x.to(tl.float32)
2 changes: 1 addition & 1 deletion src/flag_gems/fused/silu_and_mul.py
@@ -7,7 +7,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def silu_and_mul_kernel(x, y):
x_fp32 = x.to(tl.float32)
2 changes: 1 addition & 1 deletion src/flag_gems/ops/abs.py
@@ -6,7 +6,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, "COMPLEX_TO_FLOAT")])
@triton.jit
def abs_func(x):
return tl.abs(x)
10 changes: 7 additions & 3 deletions src/flag_gems/ops/add.py
@@ -6,19 +6,23 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic(is_tensor=[True, True, False])
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(is_tensor=[True, False, False])
@pointwise_dynamic(
is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
return x + y * alpha


@pointwise_dynamic(is_tensor=[False, True, False])
@pointwise_dynamic(
is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
return x + y * alpha
4 changes: 2 additions & 2 deletions src/flag_gems/ops/bitwise_and.py
@@ -5,7 +5,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_and_func(x, y):
return x & y
@@ -16,7 +16,7 @@ def bitwise_and_tensor(A, B):
return bitwise_and_func(A, B)


@pointwise_dynamic(is_tensor=[True, False])
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_and_func_scalar(x, y):
return x & y
2 changes: 1 addition & 1 deletion src/flag_gems/ops/bitwise_not.py
@@ -5,7 +5,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def bitwise_not_func(x):
return ~x
4 changes: 2 additions & 2 deletions src/flag_gems/ops/bitwise_or.py
@@ -5,7 +5,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_or_func(x, y):
return x | y
@@ -16,7 +16,7 @@ def bitwise_or_tensor(A, B):
return bitwise_or_func(A, B)


@pointwise_dynamic(is_tensor=[True, False])
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def bitwise_or_func_scalar(x, y):
return x | y
14 changes: 8 additions & 6 deletions src/flag_gems/ops/clamp.py
@@ -6,19 +6,19 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
return tl.maximum(mini, x.to(tl.float32))


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
return tl.minimum(maxi, x.to(tl.float32))
@@ -36,19 +36,21 @@ def clamp_tensor(A, mini=None, maxi=None):
return clamp_func_tensor(A, mini, maxi)


@pointwise_dynamic(is_tensor=[True, False, False])
@pointwise_dynamic(
is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def clamp_func(x, mini, maxi):
return tl.minimum(maxi, tl.maximum(mini, x.to(tl.float32)))


@pointwise_dynamic(is_tensor=[True, False])
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min(x, mini):
return tl.maximum(mini, x.to(tl.float32))


@pointwise_dynamic(is_tensor=[True, False])
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max(x, maxi):
return tl.minimum(maxi, x.to(tl.float32))
2 changes: 1 addition & 1 deletion src/flag_gems/ops/cos.py
@@ -6,7 +6,7 @@
from ..utils import pointwise_dynamic


@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def cos_func(x):
return tl.cos(x.to(tl.float32))