Cambricon #87 (Draft)

wants to merge 53 commits into base: master

Conversation

FuncSherl
Collaborator

Draft for comparing code.

xuhao and others added 30 commits May 23, 2024 06:39
cambricon: fix mlu adopt problems
cambricon: fix add/sub ops test
cambricon: merge master on 0527 and fix some mlu problems
cambricon: fix bugs in mean dim
cambricon: fix bugs in grid over limits
cambricon: Cambricon merge 0603 and adopt mlu
cambricon: fix bugs in vectnorm and varmean
cambricon: fix some bugs in tests
return inp1, inp2, inp3


def cross_entropy_loss_args(dtype, batch, size):
Collaborator

Functions cross_entropy_loss_args, cumsum_args, and so on have already been implemented in the corresponding test functions; they could be deleted here.

Collaborator Author

@FuncSherl FuncSherl Jul 1, 2024

OK, will do in #80.

device = device or torch.mlu.current_device()
gen = torch.mlu.default_generators[device]
state_copy = gen.get_state()
c0, c1 = state_copy.view(torch.int64)[-2:]
Collaborator

How many bits is state_copy? After viewing it as torch.int64, is the resulting tensor longer than two elements?
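A minimal sketch of how to check this (assuming torch.mlu mirrors the torch.cuda generator API; the device index is a hypothetical example):

import torch

device = 0  # hypothetical device index
gen = torch.mlu.default_generators[device]
state_copy = gen.get_state()            # uint8 tensor holding the raw generator state
print(state_copy.numel(), "bytes")      # must be a multiple of 8 to view as torch.int64
as_int64 = state_copy.view(torch.int64)
print(as_int64.numel(), "int64 words")  # shows whether [-2:] really picks out two counters
c0, c1 = as_int64[-2:]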

@@ -0,0 +1,303 @@
from itertools import chain
Collaborator

pointwise_static.py is no longer useful. You could delete this file.

@@ -0,0 +1,47 @@
import torch
Collaborator

All forms of the pow function are collected in pow.py; the other files could be deleted.

@@ -0,0 +1,47 @@
import torch
Collaborator

ditto

@@ -0,0 +1,16 @@
import triton
Collaborator

ditto

ref_inp2 = to_reference(inp2, True)

ref_out = torch.pow(inp1, ref_inp2)
ref_out = torch.pow(inp1, ref_inp2.cpu())
Collaborator

Why not run the reference on MLU?

@@ -0,0 +1,105 @@
import torch
Collaborator

enable.py is no longer needed.

@FuncSherl
Collaborator Author

This MR is for comparing the current cambricon code with master, to make it easier to walk through the code changes; it is not intended to be merged.

@FuncSherl FuncSherl closed this Jul 1, 2024
@FuncSherl FuncSherl reopened this Jul 1, 2024

raw_res = tl.cumsum(inp_vals, axis=1)
result = raw_res + kep[:, None]
kep = result[:, BLOCK_N-1]
Collaborator

Does Triton support tensor slicing?

Collaborator

I have the same question here. Tensor slicing of this kind is not supported in Triton 2.2.
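For reference, a hedged sketch of one way to carry the running total without a tensor slice, replacing result[:, BLOCK_N - 1] with a column mask and a reduction (variable names follow the diff; assumes only Triton 2.2-level tl ops):

raw_res = tl.cumsum(inp_vals, axis=1)
result = raw_res + kep[:, None]
# select the last column via mask + reduction instead of slicing
col_idx = tl.arange(0, BLOCK_N)
is_last_col = col_idx == (BLOCK_N - 1)
kep = tl.sum(tl.where(is_last_col[None, :], result, 0.0), axis=1)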

inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# get max for each block
tmp1 = tl.where(tmp0 < inp, inp, tmp0)
Collaborator

Use tl.maximum here.
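A minimal sketch of the suggestion, using the same names as the diff:

inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# get max for each block
tmp1 = tl.maximum(tmp0, inp)  # same result as tl.where(tmp0 < inp, inp, tmp0)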

@@ -24,15 +24,18 @@ def __init__(self, op_name, torch_op, arg_func, dtype, batch, sizes):
def set_gems(self, gems_op):
self.gems_op = gems_op

def set_gems(self, gems_op):
Collaborator

Is this duplication intended?

task_num = tl.cdiv(M, BLOCK_M)
iter_num = tl.cdiv(task_num, num_prog)
if task_num % num_prog != 0:
iter_num = iter_num + 1
Collaborator

Does this conflict with tl.cdiv?

Collaborator

Why add 1 when iter_num is already tl.cdiv(task_num, num_prog)?
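A plain-Python sketch of the point, with hypothetical sizes, showing that the extra increment over-counts whenever the division is not exact:

M, BLOCK_M, num_prog = 1000, 128, 3               # hypothetical sizes
task_num = (M + BLOCK_M - 1) // BLOCK_M           # what tl.cdiv(M, BLOCK_M) computes -> 8
iter_num = (task_num + num_prog - 1) // num_prog  # tl.cdiv already rounds up -> 3
if task_num % num_prog != 0:                      # 8 % 3 != 0, so this branch fires
    iter_num = iter_num + 1                       # iter_num becomes 4: one loop too many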

task_num = tl.cdiv(M, BLOCK_M)
iter_num = tl.cdiv(task_num, num_prog)
if task_num % num_prog != 0:
iter_num = iter_num + 1
Collaborator

ditto

@@ -11,7 +11,7 @@
@triton.jit
def gelu_none_and_mul_kernel(x, y):
x_fp32 = x.to(tl.float32)
x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811))
x_gelu = 0.5 * x_fp32 * (1 + tl.extra.mlu.libdevice.erf(x_fp32 * 0.7071067811))
Collaborator

What version of Triton are you targeting now? This path to erf seems to exist only in versions after 2.3.
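A hedged sketch of one way to keep the kernel portable across Triton builds. The tl.extra.mlu.libdevice path is assumed to exist only in the Cambricon/newer builds and tl.math.erf in stock Triton 2.2; the kernel body is a reconstruction from the diff, not the full original:

import triton
import triton.language as tl

# resolve the erf symbol once at import time; the JIT kernel picks it up as a global
if hasattr(getattr(tl, "extra", None), "mlu"):
    _erf = tl.extra.mlu.libdevice.erf
else:
    _erf = tl.math.erf

@triton.jit
def gelu_none_and_mul_kernel(x, y):
    x_fp32 = x.to(tl.float32)
    x_gelu = 0.5 * x_fp32 * (1 + _erf(x_fp32 * 0.7071067811))
    return x_gelu * y  # assumed "_and_mul" behavior; the diff only shows the erf line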

"M",
"N",
],
)
Collaborator

Using only M and N in the tuning key, when M, N, and K may all affect performance, can cause unexpected behavior. For example, the best config tuned for (m, n, k1) may not be the best config for (m, n, k2), so earlier runs affect the performance of later ones.
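A minimal sketch of the suggestion, with a hypothetical config list and kernel signature, adding K to the autotune key so shapes that differ only in K are tuned independently:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64}),
    ],
    key=["M", "N", "K"],  # include K so (m, n, k1) and (m, n, k2) do not share one cached config
)
@triton.jit
def mm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # body omitted; only the tuning key is illustrated here
    pass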

@@ -58,6 +60,96 @@ def log_softmax_kernel(
tl.store(output_ptrs, softmax_output, mask=mask)


@libentry()
Collaborator

We have a new implementation in #76 that improves performance a lot; maybe you can test that too.
