Cambricon #87
base: master
Conversation
cambricon: fix mlu adopt problems
cambricon: fix add/sub ops test
cambricon: merge master on 0527 and fix some mlu problems
cambricon: fix bugs in mean dim
cambricon: fix bugs in grid over limits
cambricon: Cambricon merge 0603 and adopt mlu
cambricon: fix bugs in vectnorm and varmean
cambricon: rm old tests
cambricon: fix some bugs in tests
Cambricon merge 0611
benchmark/performance_utils.py
Outdated
```python
    return inp1, inp2, inp3


def cross_entropy_loss_args(dtype, batch, size):
```
Functions such as cross_entropy_loss_args and cumsum_args are already implemented in the corresponding test functions; they could be deleted here.
ok, will do in #80
```python
device = device or torch.mlu.current_device()
gen = torch.mlu.default_generators[device]
state_copy = gen.get_state()
c0, c1 = state_copy.view(torch.int64)[-2:]
```
How many bits is state_copy? Does it contain more than two elements after being viewed as torch.int64?
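For context, a hedged sketch of how one could inspect the state size, assuming the MLU generator mirrors the CUDA generator API (where get_state() returns a 1-D uint8 tensor; the exact layout is backend-specific):

```python
import torch

# Sketch only: uses the CUDA generator as a stand-in for torch.mlu,
# since both expose the same default_generators / get_state() API.
gen = torch.cuda.default_generators[0]
state = gen.get_state()          # 1-D uint8 tensor holding the RNG state
print(state.numel())             # state size in bytes
words = state.view(torch.int64)  # reinterpret as 64-bit words
print(words.numel())             # typically well over 2 entries
c0, c1 = words[-2:]              # the PR reads the last two words,
                                 # presumably the Philox seed and offset
```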
```diff
@@ -0,0 +1,303 @@
+from itertools import chain
```
pointwise_static.py is no longer used; this file could be deleted.
src/flag_gems/ops/pow_scalar.py
Outdated
```diff
@@ -0,0 +1,47 @@
+import torch
```
All forms of the pow function are collected in pow.py; the other files could be deleted.
```diff
@@ -0,0 +1,47 @@
+import torch
```
ditto
```diff
@@ -0,0 +1,16 @@
+import triton
```
ditto
```diff
 ref_inp2 = to_reference(inp2, True)

-ref_out = torch.pow(inp1, ref_inp2)
+ref_out = torch.pow(inp1, ref_inp2.cpu())
```
Why not run the reference on MLU?
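For comparison, a minimal sketch of the two options, assuming torch.pow supports this dtype combination on MLU (ref_inp2 comes from the diff above; to_reference is this repo's test helper, whose exact device behavior is assumed here):

```python
# On-device reference: no host round-trip, but relies on MLU's own pow.
ref_out = torch.pow(inp1, ref_inp2)

# CPU reference, as in the diff: forces a device-to-host copy of the
# exponent, but decouples the reference result from the MLU backend.
ref_out = torch.pow(inp1, ref_inp2.cpu())
```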
src/flag_gems/__enable__.py
Outdated
```diff
@@ -0,0 +1,105 @@
+import torch
```
__enable__.py is no longer needed.
This MR is for comparing the code differences between the current cambricon branch and master, to make the code changes easier to walk through; it is not intended to be merged.
```python
raw_res = tl.cumsum(inp_vals, axis=1)
result = raw_res + kep[:, None]
kep = result[:, BLOCK_N - 1]
```
Does Triton support tensor slicing?
I have the same question here. Tensor slicing of this kind is not supported in Triton 2.2.
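A possible workaround for Triton versions without tensor slicing is to select the last column with a mask and a reduction; a sketch against the snippet above (BLOCK_N is the compile-time block width):

```python
# Inside the kernel, instead of kep = result[:, BLOCK_N - 1]:
col = tl.arange(0, BLOCK_N)
is_last = col == (BLOCK_N - 1)
# Zero out every column except the last, then reduce the row away;
# the sum leaves exactly the last column's value per row.
kep = tl.sum(tl.where(is_last[None, :], result, 0.0), axis=1)
```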
src/flag_gems/ops/log_softmax.py
Outdated
```python
inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# get max for each block
tmp1 = tl.where(tmp0 < inp, inp, tmp0)
```
Use tl.maximum here.
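That is, the running max can use the built-in directly; a drop-in equivalent of the tl.where line above:

```python
tmp1 = tl.maximum(tmp0, inp)  # same result as tl.where(tmp0 < inp, inp, tmp0)
```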
benchmark/performance_utils.py
Outdated
```diff
@@ -24,15 +24,18 @@ def __init__(self, op_name, torch_op, arg_func, dtype, batch, sizes):
     def set_gems(self, gems_op):
         self.gems_op = gems_op

+    def set_gems(self, gems_op):
```
Is this duplication intended?
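If not, the fix is presumably to keep a single definition; a sketch of the deduplicated method:

```python
def set_gems(self, gems_op):
    self.gems_op = gems_op  # one definition is enough; drop the duplicate
```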
src/flag_gems/ops/mean.py
Outdated
```python
task_num = tl.cdiv(M, BLOCK_M)
iter_num = tl.cdiv(task_num, num_prog)
if task_num % num_prog != 0:
    iter_num = iter_num + 1
```
Doesn't this conflict with tl.cdiv?
Why add 1 when iter_num is already tl.cdiv(task_num, num_prog)?
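A worked example of why the branch looks redundant: tl.cdiv already rounds up, so the extra increment over-counts whenever there is a remainder.

```python
# task_num = 10, num_prog = 4
# tl.cdiv(10, 4) == (10 + 4 - 1) // 4 == 3   -> already covers the remainder
# the `if task_num % num_prog != 0` branch then bumps it to 4,
# one iteration more than needed
iter_num = tl.cdiv(task_num, num_prog)  # sufficient on its own
```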
src/flag_gems/ops/var_mean.py
Outdated
```python
task_num = tl.cdiv(M, BLOCK_M)
iter_num = tl.cdiv(task_num, num_prog)
if task_num % num_prog != 0:
    iter_num = iter_num + 1
```
ditto
```diff
@@ -11,7 +11,7 @@
 @triton.jit
 def gelu_none_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
-    x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811))
+    x_gelu = 0.5 * x_fp32 * (1 + tl.extra.mlu.libdevice.erf(x_fp32 * 0.7071067811))
```
What version of Triton are you targeting now? This path to erf seems to exist only in versions after 2.3.
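One version-tolerant pattern would be to resolve erf once at import time; a sketch, assuming the module paths in the diff above are correct for the MLU backend:

```python
import triton
import triton.language as tl

_TRITON_VERSION = tuple(int(p) for p in triton.__version__.split(".")[:2])

if _TRITON_VERSION >= (2, 3):
    _erf = tl.extra.mlu.libdevice.erf  # backend path used in this diff
else:
    _erf = tl.math.erf                 # portable path in Triton <= 2.2

# Kernels can then call _erf(...) regardless of the Triton version.
```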
"M", | ||
"N", | ||
], | ||
) |
Using only M & N in tunning key while M, N &K may all affect the performance would cause some unexpected behavior. For example, the best config tunned for (m, n, k1) may not be the same config tunned from (m, n, k2), so previous runs have effects on peformance afterwards.
```diff
@@ -58,6 +60,96 @@ def log_softmax_kernel(
     tl.store(output_ptrs, softmax_output, mask=mask)


+@libentry()
```
We have a new implementation in #76, which improves performance a lot; maybe you can test that too.
cambricon: perf test opt
Cambricon merge 0708
draft for compare code