Merge pull request #72 from FlagOpen/dev_embedding
Develop Embedding[SiliconFlow]
Bowen12992 authored Jul 18, 2024
2 parents a5f2764 + 40e0731 commit 7156a8f
Showing 5 changed files with 252 additions and 0 deletions.
21 changes: 21 additions & 0 deletions benchmark/test_special_perf.py
@@ -51,3 +51,24 @@ def test_perf_rand_like():
        sizes=SIZES,
    )
    bench.run()


def test_perf_embedding():
    def embedding_kwargs(dtype, batch, size):
        input = torch.randint(0, batch, (batch,), device="cuda")
        weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype)
        return {"input": input, "weight": weight}

    bench = Benchmark(
        op_name="embedding",
        torch_op=torch.nn.functional.embedding,
        arg_func=None,
        dtypes=[
            torch.float32,
            torch.float16,
        ],  # Note(Zhengzekang): Triton does not support bfloat16 atomic_add, which is used in the embedding grad.
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=embedding_kwargs,
    )
    bench.run()
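For context, `embedding_kwargs` builds a 1-D index tensor in `[0, batch)` and a `(batch + 1, size)` weight table, so every lookup is in range. A minimal sketch of the same construction outside the benchmark harness (the concrete sizes below are illustrative, not values from `POINTWISE_BATCH`/`SIZES`):

```python
import torch

batch, size = 1024, 4096  # illustrative sizes only
inp = torch.randint(0, batch, (batch,), device="cuda")                       # int64 indices
weight = torch.randn((batch + 1, size), device="cuda", dtype=torch.float16)  # lookup table
out = torch.nn.functional.embedding(inp, weight)                             # row gather
assert out.shape == (batch, size)
```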
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -26,6 +26,7 @@ def enable(lib=aten_lib):
lib.impl("cumsum", cumsum, "CUDA")
lib.impl("div.Tensor", div, "CUDA")
lib.impl("native_dropout", native_dropout, "AutogradCUDA")
lib.impl("embedding", embedding, "AutogradCUDA")
lib.impl("eq.Tensor", eq, "CUDA")
lib.impl("eq.Scalar", eq_scalar, "CUDA")
lib.impl("exp", exp, "CUDA")
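Registering the op under the `AutogradCUDA` key means both the forward gather and the Triton backward are picked up once the library is enabled. A minimal usage sketch (shapes are illustrative; `flag_gems.enable()` is the global switch shown above, while the tests below use the scoped `flag_gems.use_gems()` context manager instead):

```python
import torch
import flag_gems

flag_gems.enable()  # route registered aten ops, including embedding, to the Triton kernels

weight = torch.randn(1000, 128, device="cuda", dtype=torch.float16, requires_grad=True)
indices = torch.randint(0, 1000, (4, 8), device="cuda")

out = torch.nn.functional.embedding(indices, weight)  # GEMS forward
out.sum().backward()                                  # GEMS backward via AutogradCUDA
print(weight.grad.shape)                              # torch.Size([1000, 128])
```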
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -19,6 +19,7 @@
from .cumsum import cumsum
from .div import div
from .dropout import native_dropout
from .embedding import embedding
from .eq import eq, eq_scalar
from .exp import exp
from .ge import ge, ge_scalar
@@ -88,6 +89,7 @@
"cumsum",
"div",
"native_dropout",
"embedding",
"eq",
"eq_scalar",
"exp",
192 changes: 192 additions & 0 deletions src/flag_gems/ops/embedding.py
@@ -0,0 +1,192 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output
    in_ptr,  # pointer to the input indices
    weight_ptr,  # pointer to the weight table
    N: tl.constexpr,  # embedding dimension (number of columns in weight)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    out_ptr += pid * N
    in_ptr += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(in_ptr)
    weight_ptr += row_idx * N
    embedding_weight = tl.load(weight_ptr + cols, mask, other=0.0)
    tl.store(out_ptr + cols, embedding_weight, mask)


@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # pointer to the per-index frequency counters
    indices,  # pointer to the input indices
    elem_cnt: tl.constexpr,  # total number of indices
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    block_start = pid * INDICE_BLOCK_SIZE

    offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
    mask = offsets < elem_cnt

    index_element = tl.load(indices + offsets, mask=mask)
    tl.atomic_add(indices_freq + index_element, 1, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the weight gradient (accumulated here)
    grad_out,  # pointer to the incoming output gradient
    indices,  # pointer to the input indices
    padding_idx,  # rows equal to padding_idx receive no gradient
    HAS_PADDING_IDX: tl.constexpr,
    N: tl.constexpr,  # embedding dimension (number of columns in weight)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    grad_out += pid * N
    indices += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,
    indice_freq,
    n_rows,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)

    for row_idx in range(row_start, n_rows, row_step):
        embedding_scale = 1.0
        indice_freq_val = tl.load(indice_freq + row_idx)
        if indice_freq_val > 1:
            embedding_scale = 1.0 / indice_freq_val

        cols = tl.arange(0, BLOCK_SIZE)
        mask = tl.arange(0, BLOCK_SIZE) < N
        embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
        scaled_embedding_grad = embedding_grad * embedding_scale
        tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)


class Embedding(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False
    ):
        logging.debug("GEMS EMBEDDING FORWARD")
        assert not sparse, "Currently do not support sparse format"

        M = math.prod(indices.shape)
        N = weight.shape[-1]

        BLOCK_SIZE = triton.next_power_of_2(N)
        indices = indices.contiguous()
        weight = weight.contiguous()
        output = torch.empty(
            (*indices.shape, N), device=indices.device, dtype=weight.dtype
        )

        with torch.cuda.device(weight.device):
            embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)

        if padding_idx is not None and padding_idx < 0:
            padding_idx = weight.shape[0] + padding_idx

        ctx.M = M
        ctx.N = N
        ctx.num_weights = weight.shape[0]
        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx.sparse = sparse
        ctx.indices = indices

        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        logging.debug("GEMS EMBEDDING BACKWARD")
        assert not ctx.sparse, "Currently do not support sparse format"

        grad_inputs = torch.zeros(
            (ctx.num_weights, grad_outputs.shape[-1]),
            device=grad_outputs.device,
            dtype=grad_outputs.dtype,
        )

        if ctx.scale_grad_by_freq:
            indice_freq = torch.zeros(
                (ctx.num_weights,),
                requires_grad=False,
                device=grad_outputs.device,
                dtype=torch.int32,
            )
            INDICE_BLOCK_SIZE = 256
            indice_grid = lambda meta: (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)

            with torch.cuda.device(grad_outputs.device):
                indice_freq_kernel[indice_grid](
                    indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE
                )
        else:
            indice_freq = None

        BLOCK_SIZE = triton.next_power_of_2(ctx.N)

        HAS_PADDING_IDX = ctx.padding_idx is not None

        with torch.cuda.device(grad_outputs.device):
            embedding_backward_kernel[ctx.M,](
                grad_inputs,
                grad_outputs,
                ctx.indices,
                ctx.padding_idx,
                HAS_PADDING_IDX,
                ctx.N,
                BLOCK_SIZE,
            )

        if ctx.scale_grad_by_freq:
            with torch.cuda.device(grad_outputs.device):
                embedding_grad_scale_kernel[ctx.M,](
                    grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
                )
        return grad_inputs, None, None, None, None


def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)
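The backward pass scatters rows of `grad_outputs` into the weight gradient with `tl.atomic_add`, skips rows equal to `padding_idx`, and, when `scale_grad_by_freq` is set, divides each weight row by the number of times its index appeared. A pure-PyTorch mirror of that logic is handy for sanity-checking the kernels; the function below is an illustrative sketch, not part of the module, and it assumes `padding_idx` is already normalized to a non-negative index or `None`:

```python
import torch

def embedding_backward_reference(grad_out, indices, num_weights,
                                 padding_idx=None, scale_grad_by_freq=False):
    # grad_out: (*indices.shape, N) -> returns the (num_weights, N) weight gradient
    N = grad_out.shape[-1]
    flat_idx = indices.reshape(-1)
    flat_grad = grad_out.reshape(-1, N)
    grad_weight = torch.zeros(num_weights, N, device=grad_out.device, dtype=grad_out.dtype)
    grad_weight.index_add_(0, flat_idx, flat_grad)       # same effect as the tl.atomic_add scatter
    if padding_idx is not None and 0 <= padding_idx < num_weights:
        grad_weight[padding_idx].zero_()                 # the padding row receives no gradient
    if scale_grad_by_freq:
        freq = torch.bincount(flat_idx, minlength=num_weights).clamp(min=1)
        grad_weight /= freq.unsqueeze(1).to(grad_weight.dtype)  # matches embedding_grad_scale_kernel
    return grad_weight
```

Since `tl.atomic_add` has no bfloat16 overload, both the benchmark and the tests restrict themselves to float16 and float32, as the comments above note.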
36 changes: 36 additions & 0 deletions tests/test_special_ops.py
@@ -145,6 +145,7 @@ def test_apply_rotary_pos_emb(
        position_ids=ref_position_ids if has_pos_id else None,
        rotary_interleaved=rotary_interleaved,
    )

    q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb(
        q=q,
        k=k,
@@ -158,6 +159,41 @@
    gems_assert_close(k_embed_out, k_embed_ref, dtype)


@pytest.mark.parametrize("EmbeddingSize", [4096])
@pytest.mark.parametrize("Batch", [2, 4])
@pytest.mark.parametrize("M", [4, 8])
@pytest.mark.parametrize("N", [128, 256, 4096])
@pytest.mark.parametrize("padding_idx", [None, -1, 1, 2])
@pytest.mark.parametrize("scale_grad_by_freq", [True, False])
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32]
) # triton.atomic_add still not support bf16
def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, dtype):
indices = torch.randint(
0, EmbeddingSize, (Batch, M), device="cuda", requires_grad=False
)
embedding = torch.randn(
(EmbeddingSize, N), device="cuda", dtype=dtype, requires_grad=True
)
ref_embedding = to_reference(embedding)

res_out = torch.nn.functional.embedding(
indices, embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)
with flag_gems.use_gems():
ref_out = torch.nn.functional.embedding(
indices, ref_embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)
out_grad = torch.randn_like(ref_out)
ref_grad = to_reference(out_grad)

(ref_in_grad,) = torch.autograd.grad(ref_out, ref_embedding, ref_grad)
(res_in_grad,) = torch.autograd.grad(res_out, embedding, out_grad)

gems_assert_close(ref_out, res_out, dtype)
gems_assert_close(ref_in_grad, res_in_grad, dtype)


@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_rand(shape, dtype):
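To exercise only the new `test_embedding` cases above locally, the tests can be selected by name; a small sketch (path assumed relative to the repository root):

```python
# Equivalent to: pytest tests/test_special_ops.py -k test_embedding
import pytest

pytest.main(["tests/test_special_ops.py", "-k", "test_embedding"])
```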
