Develop Embedding[SiliconFlow] #72

Merged · 13 commits · Jul 18, 2024
21 changes: 21 additions & 0 deletions benchmark/test_special_perf.py
@@ -51,3 +51,24 @@ def test_perf_rand_like():
sizes=SIZES,
)
bench.run()


def test_perf_embedding():
def embedding_kwargs(dtype, batch, size):
input = torch.randint(0, batch, (batch,), device="cuda")
weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype)
return {"input": input, "weight": weight}

bench = Benchmark(
op_name="embedding",
torch_op=torch.nn.functional.embedding,
arg_func=None,
dtypes=[
torch.float32,
torch.float16,
], # Note(Zhengzekang): Triton does not support bfloat16 atomic add, which is used in the embedding backward.
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=embedding_kwargs,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -26,6 +26,7 @@ def enable(lib=aten_lib):
lib.impl("cumsum", cumsum, "CUDA")
lib.impl("div.Tensor", div, "CUDA")
lib.impl("native_dropout", native_dropout, "AutogradCUDA")
lib.impl("embedding", embedding, "AutogradCUDA")
lib.impl("eq.Tensor", eq, "CUDA")
lib.impl("eq.Scalar", eq_scalar, "CUDA")
lib.impl("exp", exp, "CUDA")
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -19,6 +19,7 @@
from .cumsum import cumsum
from .div import div
from .dropout import native_dropout
from .embedding import embedding
from .eq import eq, eq_scalar
from .exp import exp
from .ge import ge, ge_scalar
@@ -87,6 +88,7 @@
"cumsum",
"div",
"native_dropout",
"embedding",
"eq",
"eq_scalar",
"exp",
192 changes: 192 additions & 0 deletions src/flag_gems/ops/embedding.py
@@ -0,0 +1,192 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.jit
def embedding_kernel(
out_ptr, # pointer to the output
in_ptr, # pointer to the input
weight_ptr, # pointer to the weights
N: tl.constexpr, # embedding dimension (number of columns in weight)
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
out_ptr += pid * N
in_ptr += pid

mask = tl.arange(0, BLOCK_SIZE) < N
cols = tl.arange(0, BLOCK_SIZE)

row_idx = tl.load(in_ptr)
weight_ptr += row_idx * N
embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0)
tl.store(out_ptr + cols, embedding_weight, mask)


@libentry()
@triton.jit
def indice_freq_kernel(
indices_freq,
indices, # pointer to the input
elem_cnt: tl.constexpr, # total number of indices
INDICE_BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * INDICE_BLOCK_SIZE

offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
mask = offsets < elem_cnt

index_element = tl.load(indices + offsets, mask=mask)
tl.atomic_add(indices_freq + index_element, 1, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
grad_in, # pointer to the gradient input
grad_out, # pointer to the gradient output
indices, # pointer to the input
padding_idx, # padding_idx
HAS_PADDING_IDX: tl.constexpr,
N: tl.constexpr, # embedding dimension (number of columns in weight)
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
grad_out += pid * N
indices += pid

mask = tl.arange(0, BLOCK_SIZE) < N
cols = tl.arange(0, BLOCK_SIZE)

row_idx = tl.load(indices).to(tl.int32)
if not HAS_PADDING_IDX:
grad_in += row_idx * N
embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
else:
if row_idx != padding_idx:
grad_in += row_idx * N
embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
grad_out,
indice_freq,
n_rows,
N,
BLOCK_SIZE: tl.constexpr,
):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)

for row_idx in range(row_start, n_rows, row_step):
embedding_scale = 1.0
indice_freq_val = tl.load(indice_freq + row_idx)
if indice_freq_val > 1:
embedding_scale = 1.0 / indice_freq_val

cols = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < N
embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
scaled_embedding_grad = embedding_grad * embedding_scale
tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)


class Embedding(torch.autograd.Function):
@staticmethod
def forward(
ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False
):
logging.debug("GEMS EMBEDDING FORWARD")
assert not sparse, "Sparse format is not currently supported"

M = math.prod(indices.shape)
N = weight.shape[-1]

BLOCK_SIZE = triton.next_power_of_2(N)
indices = indices.contiguous()
weight = weight.contiguous()
Review comment — @Bowen12992 (Collaborator), Jul 9, 2024: Is it necessary to make these tensors contiguous here first? It may cause memory-copy overhead.

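(Context for the thread above, not part of the PR: Tensor.contiguous() returns its input unchanged when the tensor is already contiguous, so the copy the reviewer is asking about only occurs for strided inputs. A minimal illustration, assuming a CUDA device:

x = torch.arange(6, device="cuda").reshape(2, 3).t()  # non-contiguous view
y = x.contiguous()            # a real copy happens here
assert y.contiguous() is y    # already contiguous: returned as-is, no copy
)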
output = torch.empty(
(*indices.shape, N), device=indices.device, dtype=weight.dtype
)

with torch.cuda.device(weight.device):
embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)

if padding_idx is not None and padding_idx < 0:
padding_idx = weight.shape[0] + padding_idx

ctx.M = M
ctx.N = N
ctx.num_weights = weight.shape[0]
ctx.padding_idx = padding_idx
ctx.scale_grad_by_freq = scale_grad_by_freq
ctx.sparse = sparse
ctx.indices = indices

return output

@staticmethod
def backward(ctx, grad_outputs):
logging.debug("GEMS EMBEDDING BACKWARD")
assert not ctx.sparse, "Sparse format is not currently supported"

grad_inputs = torch.zeros(
(ctx.num_weights, grad_outputs.shape[-1]),
device=grad_outputs.device,
dtype=grad_outputs.dtype,
)

if ctx.scale_grad_by_freq:
indice_freq = torch.zeros(
(ctx.num_weights,),
requires_grad=False,
device=grad_outputs.device,
dtype=torch.int32,
)
INDICE_BLOCK_SIZE = 256
indice_grid = lambda meta: (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)

with torch.cuda.device(grad_outputs.device):
indice_freq_kernel[indice_grid](
indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE
)
else:
indice_freq = None

BLOCK_SIZE = triton.next_power_of_2(ctx.N)

HAS_PADDING_IDX = ctx.padding_idx is not None

with torch.cuda.device(grad_outputs.device):
embedding_backward_kernel[ctx.M,](
grad_inputs,
grad_outputs,
ctx.indices,
ctx.padding_idx,
HAS_PADDING_IDX,
ctx.N,
BLOCK_SIZE,
)

if ctx.scale_grad_by_freq:
with torch.cuda.device(grad_outputs.device):
embedding_grad_scale_kernel[ctx.M,](
grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
)
return grad_inputs, None, None, None, None


def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)
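For context, a minimal usage sketch (not part of the diff) showing how the newly registered op is reached through the regular PyTorch call once flag_gems.enable() has patched the ATen library; the vocabulary and embedding sizes are arbitrary illustration values:

import torch
import flag_gems

flag_gems.enable()  # routes supported aten ops, including embedding, to the Triton kernels

num_embeddings, embedding_dim = 1024, 128  # arbitrary sizes for illustration
weight = torch.randn(num_embeddings, embedding_dim, device="cuda",
                     dtype=torch.float16, requires_grad=True)
indices = torch.randint(0, num_embeddings, (4, 8), device="cuda")

out = torch.nn.functional.embedding(indices, weight, scale_grad_by_freq=True)
out.sum().backward()  # runs embedding_backward_kernel and, here, the grad-scale kernel

The tests below use the scoped form, with flag_gems.use_gems():, which enables the same dispatch only inside the context manager.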
36 changes: 36 additions & 0 deletions tests/test_special_ops.py
@@ -145,6 +145,7 @@ def test_apply_rotary_pos_emb(
position_ids=ref_position_ids if has_pos_id else None,
rotary_interleaved=rotary_interleaved,
)

q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb(
q=q,
k=k,
@@ -158,6 +159,41 @@
gems_assert_close(k_embed_out, k_embed_ref, dtype)


@pytest.mark.parametrize("EmbeddingSize", [4096])
@pytest.mark.parametrize("Batch", [2, 4])
@pytest.mark.parametrize("M", [4, 8])
@pytest.mark.parametrize("N", [128, 256, 4096])
@pytest.mark.parametrize("padding_idx", [None, -1, 1, 2])
@pytest.mark.parametrize("scale_grad_by_freq", [True, False])
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32]
) # tl.atomic_add does not yet support bf16
def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, dtype):
indices = torch.randint(
0, EmbeddingSize, (Batch, M), device="cuda", requires_grad=False
)
embedding = torch.randn(
(EmbeddingSize, N), device="cuda", dtype=dtype, requires_grad=True
)
ref_embedding = to_reference(embedding)

res_out = torch.nn.functional.embedding(
indices, embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)
with flag_gems.use_gems():
ref_out = torch.nn.functional.embedding(
indices, ref_embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)
out_grad = torch.randn_like(ref_out)
ref_grad = to_reference(out_grad)

(ref_in_grad,) = torch.autograd.grad(ref_out, ref_embedding, ref_grad)
(res_in_grad,) = torch.autograd.grad(res_out, embedding, out_grad)

gems_assert_close(ref_out, res_out, dtype)
gems_assert_close(ref_in_grad, res_in_grad, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_rand(shape, dtype):