Develop Embedding[SiliconFlow] #72

Merged
merged 13 commits on Jul 18, 2024
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -19,6 +19,7 @@
from .cumsum import cumsum
from .div import div
from .dropout import native_dropout
from .embedding import embedding
from .eq import eq, eq_scalar
from .exp import exp
from .ge import ge, ge_scalar
@@ -82,6 +83,7 @@
"cumsum",
"div",
"native_dropout",
"embedding",
"eq",
"eq_scalar",
"exp",
188 changes: 188 additions & 0 deletions src/flag_gems/ops/embedding.py
@@ -0,0 +1,188 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.jit
def embedding_kernel(
Y, # pointer to the output
X, # pointer to the input
W, # pointer to the weights
N: tl.constexpr, # embedding dimension (number of columns in W)
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
Y += pid * N
X += pid

mask = tl.arange(0, BLOCK_SIZE) < N
cols = tl.arange(0, BLOCK_SIZE)

row_idx = tl.load(X).to(tl.int32)
W += row_idx * N
embedding_weight = tl.load(W + cols, mask, other=0.0)
tl.store(Y + cols, embedding_weight, mask)


@libentry()
@triton.jit
def indice_freq_kernel(
indices_freq, # pointer to the index frequency counts (output)
indices, # pointer to the input
elem_cnt: tl.constexpr, # total number of indices
INDICE_BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * INDICE_BLOCK_SIZE

offsets = block_start + tl.arange(0, INDICE_BLOCK_SIZE)
mask = offsets < elem_cnt

index_element = tl.load(indices + offsets, mask=mask).to(tl.int32)
tl.atomic_add(indices_freq + index_element, 1, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
GradIn, # pointer to the gradient input
GradOut, # pointer to the gradient output
indices, # pointer to the input
padding_idx, # index whose rows receive no gradient
HAS_PADDING_IDX: tl.constexpr,
N: tl.constexpr, # embedding dimension (number of columns in GradOut)
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
GradOut += pid * N
indices += pid

mask = tl.arange(0, BLOCK_SIZE) < N
cols = tl.arange(0, BLOCK_SIZE)

row_idx = tl.load(indices).to(tl.int32)
if not HAS_PADDING_IDX:
GradIn += row_idx * N
embedding_grad = tl.load(GradOut + cols, mask, other=0.0)
tl.atomic_add(GradIn + cols, embedding_grad, mask=mask)
else:
if row_idx != padding_idx:
GradIn += row_idx * N
embedding_grad = tl.load(GradOut + cols, mask, other=0.0)
tl.atomic_add(GradIn + cols, embedding_grad, mask=mask)


@libentry()
@triton.jit(do_not_specialize=["n_rows", "N"])
def embedding_grad_scale_kernel(
grad_out, # pointer to the gradient of the weight
indice_freq, # pointer to the index frequency counts
n_rows,
N,
BLOCK_SIZE: tl.constexpr,
):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)

for row_idx in range(row_start, n_rows, row_step):
embedding_scale = 1.0
indice_freq_val = tl.load(indice_freq + row_idx).to(tl.int32)
if indice_freq_val > 1:
embedding_scale = 1.0 / indice_freq_val

cols = tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < N
embedding_grad = tl.load(grad_out + row_idx * N + cols, mask=mask)
scaled_embedding_grad = embedding_grad * embedding_scale
tl.store(grad_out + row_idx * N + cols, scaled_embedding_grad, mask=mask)


class Embedding(torch.autograd.Function):
@staticmethod
def forward(
ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False
):
logging.debug("GEMS EMBEDDING FORWARD")
assert not sparse, "Currently do not support sparse format"

M = math.prod(indices.shape)
N = weight.shape[-1]

BLOCK_SIZE = triton.next_power_of_2(N)
indices = indices.contiguous()
weight = weight.contiguous()
Bowen12992 (Collaborator) commented on Jul 9, 2024:
Is it necessary to make these contiguous here first? It may cause memory-copy overhead.
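# Note: .contiguous() returns the input tensor unchanged when it is already
# contiguous, so a copy (and its overhead) only occurs for non-contiguous inputs.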

output = torch.empty(
(*indices.shape, N), device=indices.device, dtype=weight.dtype
)

embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)

if padding_idx is not None and padding_idx < 0:
padding_idx = weight.shape[0] + padding_idx

ctx.M = M
ctx.N = N
ctx.num_weights = weight.shape[0]
ctx.padding_idx = padding_idx
ctx.scale_grad_by_freq = scale_grad_by_freq
ctx.sparse = sparse
ctx.indices = indices

return output

@staticmethod
def backward(ctx, grad_outputs):
logging.debug("GEMS EMBEDDING BACKWARD")
assert not ctx.sparse, "Currently do not support sparse format"

grad_inputs = torch.zeros(
(ctx.num_weights, grad_outputs.shape[-1]),
device=grad_outputs.device,
dtype=grad_outputs.dtype,
)

if ctx.scale_grad_by_freq:
indice_freq = torch.zeros(
(ctx.num_weights,),
requires_grad=False,
device=grad_outputs.device,
dtype=torch.int32,
)
INDICE_BLOCK_SIZE = 256
indice_grid = lambda meta: (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)
indice_freq_kernel[indice_grid](
indice_freq, ctx.indices, ctx.M, INDICE_BLOCK_SIZE
)
else:
indice_freq = None

BLOCK_SIZE = triton.next_power_of_2(ctx.N)

HAS_PADDING_IDX = ctx.padding_idx is not None
embedding_backward_kernel[ctx.M,](
grad_inputs,
grad_outputs,
ctx.indices,
ctx.padding_idx,
HAS_PADDING_IDX,
ctx.N,
BLOCK_SIZE,
)

if ctx.scale_grad_by_freq:
embedding_grad_scale_kernel[ctx.M,](
grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
)
return grad_inputs, None, None, None, None


def embedding(
indices, weight, padding_idx=None, scale_grad_by_freq=False, sparse=False
):
return Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)
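
For reference, a minimal usage sketch of the new operator (assumptions: flag_gems is installed and a CUDA device is available; the shapes below are illustrative only):

import torch
import flag_gems

# A 4096 x 128 embedding table and a (4, 8) batch of indices (illustrative shapes).
weight = torch.randn(4096, 128, device="cuda", dtype=torch.float32, requires_grad=True)
indices = torch.randint(0, 4096, (4, 8), device="cuda")

# Forward: gathers rows of `weight`; the output has shape (*indices.shape, 128).
out = flag_gems.embedding(indices, weight, padding_idx=None, scale_grad_by_freq=True)

# Backward: gradients for repeated indices are accumulated atomically and, with
# scale_grad_by_freq=True, divided by how often each index appears.
out.sum().backward()
print(weight.grad.shape)  # torch.Size([4096, 128])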
40 changes: 40 additions & 0 deletions tests/test_special_ops.py
@@ -145,6 +145,7 @@ def test_apply_rotary_pos_emb(
position_ids=ref_position_ids if has_pos_id else None,
rotary_interleaved=rotary_interleaved,
)

q_embed_out, k_embed_out = flag_gems.apply_rotary_pos_emb(
q=q,
k=k,
@@ -156,3 +157,42 @@

gems_assert_close(q_embed_out, q_embed_ref, dtype)
gems_assert_close(k_embed_out, k_embed_ref, dtype)


@pytest.mark.parametrize("EmbeddingSize", [4096])
@pytest.mark.parametrize("Batch", [2, 4])
@pytest.mark.parametrize("M", [4, 8])
@pytest.mark.parametrize("N", [128, 256, 4096])
@pytest.mark.parametrize("padding_idx", [None, -1, 1, 2])
@pytest.mark.parametrize("scale_grad_by_freq", [True, False])
@pytest.mark.parametrize(
"dtype", [torch.float32]
) # triton.atomic_add does not support bf16 yet
def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, dtype):
torch.manual_seed(0)
indices = torch.randint(
0, EmbeddingSize, (Batch, M), device="cuda", requires_grad=False
)
embedding = torch.randn(
(EmbeddingSize, N), device="cuda", dtype=dtype, requires_grad=True
)
ref_embedding = to_reference(embedding)

res_out = torch.nn.functional.embedding(
indices, embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)
ref_out = flag_gems.embedding(
indices, ref_embedding, padding_idx, scale_grad_by_freq=scale_grad_by_freq
)

out_grad = torch.randn_like(ref_out)
ref_grad = to_reference(out_grad)

(ref_in_grad,) = torch.autograd.grad(ref_out, ref_embedding, ref_grad)
(res_in_grad,) = torch.autograd.grad(res_out, embedding, out_grad)

res_out = to_reference(res_out)
res_in_grad = to_reference(res_in_grad)

gems_assert_close(ref_out, res_out, dtype)
gems_assert_close(ref_in_grad, res_in_grad, dtype)