Skip to content

Commit

Permalink
change docs for pointwise_dy
Browse files Browse the repository at this point in the history
  • Loading branch information
Bowen12992 committed Jun 26, 2024
1 parent d4c36d6 commit 26147cb
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 7 deletions.
31 changes: 28 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ In FlagGems, we provide automatic code generation that developers can use to con
Decorating the pointwise operator function with `pointwise_dynamic` can save the manual handling of tensor addressing, tensor read/write, parallel tiling, tensor broadcasting, dynamic dimensions, non-contiguous storage, etc. For example, in the following code, developers only need to describe the computational logic to generate flexible and efficient Triton code.

```python
@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]])
@triton.jit
def abs_func(x):
return tl.abs(x)
Expand All @@ -29,15 +29,19 @@ def abs_func(x):
By default, `pointwise_dynamic` treats all parameters as tensors, and by passing a list of boolean values to the parameter `is_tensor`, developers can specify which parameters are tensors and which are not. Additionally, developers can pass in `dtypes` to indicate the data types of non-tensor parameters, but this is not required. For example, in the following code, the `alpha` parameter is defined as a non-tensor floating point number, while the `x` and `y` parameters are defined as tensors.

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@pointwise_dynamic(
is_tensor=[True, True, False],
dtypes=[None, None, float],
promotion_methods=[[0,"DEFAULT"]]
)
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### Output Data Type

By default, all output tensors have the same data type as the first input tensor, but it can also be customized by providing a list of data types to the parameter `output_dtypes`. For example, in the following code, the output tensor type is specified as `torch.bool`.
Furthermore, developers MUST provide promotion_methods to specify how type promotion should be handled for the operation to achieve the correct output type during computation.

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
Expand All @@ -46,6 +50,27 @@ def ge(x, y):
return x > y
```

In `promotion_methods`, an `int` is used to indicate the position of the parameter requiring type promotion, while a `str` denotes the method of type promotion. The `str` corresponds to the following enumerated types:

```python
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
DEFAULT = (0,)
NO_OPMATH = (1,)
INT_TO_FLOAT = (2,)
ALWAYS_BOOL = (3,)
COMPLEX_TO_FLOAT = (4,)
BOOL_TO_LONG = (5,)
```

Examples:

- `DEFAULT` :add
- `NO_OPMATH` : where, nextafter, cat
- `INT_TO_FLOAT` :sin
- `ALWAYS_BOOL` :eq
- `COMPLEX_TO_FLOAT` :abs
- `BOOL_TO_LONG` :pow

## Changelog

### v1.0
Expand Down
33 changes: 29 additions & 4 deletions README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ FlagGems通过对PyTorch的后端aten算子进行覆盖重写,实现算子库
在对位算子函数前装饰`pointwise_dynamic`,可以节省张量寻址、张量读写、并行分块、张量广播、动态维度、非连续存储等的手动处理。例如以下代码,开发者只需简单描述计算逻辑,即可生成灵活高效的Triton核函数与包装代码。

```python
@pointwise_dynamic
@pointwise_dynamic(promotion_methods=[[0, "COMPLEX_TO_FLOAT"]])
@triton.jit
def abs_func(x):
return tl.abs(x)
Expand All @@ -29,23 +29,48 @@ def abs_func(x):
在默认情况下,`pointwise_dynamic`将所有参数均处理为张量,而通过向参数`is_tensor`传递布尔值列表,开发者可以指定哪些参数是张量,哪些参数非张量。此外,开发者还可以传入`dtypes`说明非张量参数的数据类型,但这不是必要的。例如以下代码,将`alpha`参数定义为非张量的浮点数,而`x``y`参数定义为张量。

```python
@pointwise_dynamic(is_tensor=[True, True, False], dtypes=[None, None, float])
@pointwise_dynamic(
is_tensor=[True, True, False],
dtypes=[None, None, float],
promotion_methods=[[0,"DEFAULT"]]
)
@triton.jit
def add_func(x, y, alpha):
return x + y * alpha
```

#### 输出数据类型

在默认情况下,输出张量使用与首个输入张量相同的数据类型,但也可向参数`output_dtypes`传入数据类型组成的列表来指定。例如以下代码,指定输出张量类型为`torch.bool`
此外,开发者必须传入 `promotion_methods` 来说明该 Op 在进行计算时应该如何进行`类型提升`以获得正确的输出类型

```python
@pointwise_dynamic(output_dtypes=[torch.bool])
@pointwise_dynamic(promotion_methods=[[0, "ALWAYS_BOOL"]])
@triton.jit
def ge(x, y):
return x > y
```

`promotion_methods` 通过传入 `int` 来表示需要进行类型提升的参数位置, 通过传入 `str` 来表示类型提升的方式, `str` 对于以下枚举类型

```python
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
DEFAULT = (0,)
NO_OPMATH = (1,)
INT_TO_FLOAT = (2,)
ALWAYS_BOOL = (3,)
COMPLEX_TO_FLOAT = (4,)
BOOL_TO_LONG = (5,)
```

举例:

- `DEFAULT` :add
- `NO_OPMATH` : where, nextafter, cat
- `INT_TO_FLOAT` :sin
- `ALWAYS_BOOL` :eq
- `COMPLEX_TO_FLOAT` :abs
- `BOOL_TO_LONG` :pow

## 更新日志

### v1.0
Expand Down

0 comments on commit 26147cb

Please sign in to comment.