Multinomial #141

Merged
56 commits merged on Sep 2, 2024
Changes from all commits
Commits (56)
1575860
WIP: multinomial
tongxin Jul 28, 2024
4e4879e
add Ops & UT & Bench
Jul 29, 2024
00da77e
Merge branch 'master' of https://github.com/FlagOpen/FlagGems into op…
Jul 29, 2024
22e14de
add full zero ones Ops & UT & Bench
Jul 30, 2024
5063551
split normal op
Jul 30, 2024
55987a6
Adding multinomial.
tongxin Jul 30, 2024
0f6822a
fixed one off error in binary search
tongxin Jul 31, 2024
c894eb3
Added multinomial tests without replacement.
tongxin Aug 1, 2024
521dd05
PR comment
Aug 1, 2024
f39ae12
Merge branch 'master' of https://github.com/FlagOpen/FlagGems into op…
Aug 1, 2024
732edca
split test_special_ops
Aug 1, 2024
85e93b3
updated with_replacement tests
tongxin Aug 1, 2024
145ed76
add K-S test
Aug 1, 2024
273efc8
split special perf
Aug 1, 2024
f46da90
Update to a more reliable without-replacement test
tongxin Aug 1, 2024
4a5acdd
Exponential added. (#138)
tongxin Aug 1, 2024
2f725c4
table
Aug 2, 2024
f1b27d1
resolve conflict
Aug 2, 2024
9e878c4
Use int64 indexing when needed & fix argmax (#146)
iclementine Aug 2, 2024
371f108
test for op
Aug 2, 2024
6fcc8e7
test for op
Aug 2, 2024
a7feb4b
Added multinomial perf tests.
tongxin Aug 4, 2024
369aa82
Making libentry thread safe (#136)
tongxin Aug 5, 2024
423731d
add argparse
Aug 6, 2024
0da6086
fix desc
Aug 6, 2024
de4b098
fix num
Aug 6, 2024
980c9d7
Update test_specific_ops.py
Bowen12992 Aug 6, 2024
e77db4c
Merge pull request #151 from Bowen12992/test_for_op
Bowen12992 Aug 6, 2024
22654f1
split UT files
Aug 6, 2024
78bdaf2
Merge branch 'master' of https://github.com/FlagOpen/FlagGems into op…
Aug 6, 2024
0dde1a6
fix
Aug 6, 2024
34c7522
fix
Aug 6, 2024
7042de1
Merge pull request #139 from Bowen12992/op_dev
Bowen12992 Aug 6, 2024
dc9ce0c
Merge branch 'master' into multinomial
tongxin Aug 6, 2024
ce37522
resolved conflicts with master.
tongxin Aug 6, 2024
376344c
Move multinomial hypothesis tests to test_distribution_ops, resolve c…
tongxin Aug 7, 2024
db5dd1e
fixing multinomial, working in progress.
tongxin Aug 19, 2024
6c6a612
Multinomial passes tests.
tongxin Aug 19, 2024
3d40f7a
Enhance multinomial tests and benchmarks.
tongxin Aug 19, 2024
134bef1
resolve conflicts.
tongxin Aug 19, 2024
e25e04e
Merge branch 'master' into multinomial
StrongSpoon Aug 20, 2024
e7b2c5d
[bugfix] keepdim when samples one
StrongSpoon Aug 20, 2024
47e76ad
[bugfix] fix accu test
StrongSpoon Aug 21, 2024
55c84f7
fix anomaly behavior in fused_renorm_cumsum
tongxin Aug 25, 2024
aa66384
Polish multinomial tests.
tongxin Aug 26, 2024
650981d
remove garbage files.
tongxin Aug 26, 2024
71f1860
multinomial result casted to int64 for n_samples==1
tongxin Aug 26, 2024
82e3a4d
Merge with master
tongxin Aug 26, 2024
a8f9a47
bfloat16 added for multinomial, polish without replacement test.
tongxin Aug 28, 2024
3f6ec74
merged with master.
tongxin Aug 28, 2024
afd2c0d
Enable two-pass normed cumsum.
tongxin Aug 31, 2024
bfd62e5
cumsum updated
tongxin Aug 31, 2024
572853c
normed cumsum complete.
tongxin Sep 1, 2024
85a472e
Fixed multinomial binary search boundary bug
tongxin Sep 1, 2024
eafe33d
fix normed_cumsum bugs.
tongxin Sep 1, 2024
459bb61
quick fix dim check.
tongxin Sep 1, 2024
17 changes: 17 additions & 0 deletions benchmark/test_special_perf.py
@@ -98,6 +98,23 @@ def unique_kwargs(dtype, batch, size):
    bench.run()


def test_multinomial_with_replacement():
    def multinomial_args(dtype, batch, size):
        dist = torch.rand((batch, size), dtype=dtype, device="cuda")
        n_samples = 10000
        return (dist, n_samples, True)

    bench = Benchmark(
        op_name="multinomial",
        torch_op=torch.multinomial,
        arg_func=multinomial_args,
        dtypes=(torch.float16, torch.float32),
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_pad():
    def padding_kwargs(dtype, batch, size):
        input = torch.randn((batch, size), device="cuda", dtype=dtype)
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -82,6 +82,7 @@ def enable(lib=aten_lib):
lib.impl("mean.dim", mean_dim, "CUDA")
lib.impl("mm", mm, "CUDA")
lib.impl("mul.Tensor", mul, "CUDA")
lib.impl("multinomial", multinomial, "CUDA")
lib.impl("mv", mv, "CUDA")
lib.impl("ne.Tensor", ne, "CUDA")
lib.impl("ne.Scalar", ne_scalar, "CUDA")
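
For context, a minimal usage sketch (not part of the diff; the shapes, dtype, and sample count below are illustrative): once enable() has registered the op, torch.multinomial on CUDA tensors dispatches to the Triton implementation added in this PR.

import torch
import flag_gems

flag_gems.enable()
# Unnormalized row weights; multinomial normalizes them internally.
weights = torch.rand(4, 1024, device="cuda", dtype=torch.float32)
# Draw 100 samples per row, with replacement.
samples = torch.multinomial(weights, num_samples=100, replacement=True)
print(samples.shape)  # torch.Size([4, 100])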
5 changes: 4 additions & 1 deletion src/flag_gems/ops/__init__.py
@@ -16,7 +16,7 @@
from .clamp import clamp, clamp_tensor
from .cos import cos
from .cross_entropy_loss import cross_entropy_loss
from .cumsum import cumsum
from .cumsum import cumsum, normed_cumsum
from .div import div_mode, floor_divide, true_divide
from .dropout import native_dropout
from .embedding import embedding
@@ -48,6 +48,7 @@
from .minimum import minimum
from .mm import mm
from .mul import mul
from .multinomial import multinomial
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
@@ -115,6 +116,7 @@
"cos",
"pad",
"cumsum",
"normed_cumsum",
"true_divide",
"div_mode",
"floor_divide",
@@ -153,6 +155,7 @@
"mean_dim",
"mm",
"mul",
"multinomial",
"maximum",
"minimum",
"rand",
257 changes: 257 additions & 0 deletions src/flag_gems/ops/cumsum.py
@@ -76,3 +76,260 @@ def cumsum(inp, dim=1, *, dtype=None):
    with torch.cuda.device(inp.device):
        cumsum_kernel[grid](inp, out, M, N, K)
    return out


@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    row_start = tl.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)


@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y * r, (grid.y + 1) * r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_chunks = tl.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE


@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tl.program_id(0).to(tl.int64)
    gridy = tl.program_id(1).to(tl.int64)
    n_gridx = tl.num_programs(1)

    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE
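

# Worked example of the two-pass recombination: for a row x = [1, 2, 3, 4] split
# into chunks [1, 2] | [3, 4], block_cumsum_kernel produces per-chunk cumsums
# [1, 3] | [3, 7] and chunk sums sums = [3, 7]; a second scan over the chunk sums
# gives cumsums = [3, 10]. block_update_kernel then adds the per-chunk offset
# (cumsums - sums) = [0, 3] and divides by the row total rscale = 10, yielding the
# normed cumsum [0.1, 0.3, 0.6, 1.0].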


GRID_Y_LIMIT = 65535


def normed_cumsum(inp, dim=-1):
logging.debug("GEMS NORMED_CUMSUM")
assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
dim = dim % inp.ndim
N = inp.numel()
K = inp.size(dim)
# inp = inp.contiguous()
# First and last dims are easier to handle, but transpose the middle dim to the last
ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
if is_mid_dim:
inp = inp.transpose(dim, -1).contiguous()
dim = -1
out = torch.empty_like(inp)
    with torch.cuda.device(inp.device.index):
        # Pass one: scan a (batch, n_tiles * TILE) sized block within each CTA.
        num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
        TILE = 2048
        # Each row is split into n_chunks chunks, where each chunk is comprised of
        # n_tiles tiles. Different chunks are assigned to different CTAs.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # Accumulate partial sums in float32, except for float64 inputs.
        acc_dtype = torch.float32 if inp.dtype != torch.float64 else torch.float64
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device="cuda")
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # print(sums)
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
    return out
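
As a quick sanity sketch for normed_cumsum (not part of the diff; it assumes the op is importable from flag_gems.ops as the export list above suggests, and the shape, dtype, and tolerances are illustrative), the result should match a plain cumsum divided by the per-row total:

import torch
from flag_gems.ops import normed_cumsum

x = torch.rand(8, 4096, device="cuda", dtype=torch.float32)
# Reference: inclusive scan normalized by the row sum.
ref = x.cumsum(dim=-1) / x.sum(dim=-1, keepdim=True)
res = normed_cumsum(x, dim=-1)
torch.testing.assert_close(res, ref, rtol=1e-4, atol=1e-4)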