doc: update documentation index (#603)
Many APIs are missing from the documentation; this PR fixes the index.
yzh119 authored Nov 11, 2024
1 parent 595cf60 commit d305d56
Showing 8 changed files with 63 additions and 4 deletions.
18 changes: 18 additions & 0 deletions docs/api/python/activation.rst
@@ -0,0 +1,18 @@
.. _apiactivation:

flashinfer.activation
=====================

.. currentmodule:: flashinfer.activation

This module provides a set of activation operations for up/gate layers in transformer MLPs.

Up/Gate output activation
-------------------------

.. autosummary::
:toctree: ../../generated

silu_and_mul
gelu_tanh_and_mul
gelu_and_mul
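
A minimal usage sketch for these kernels (assuming the shapes and the optional ``out`` argument described in the ``flashinfer.activation`` docstrings further down in this diff; illustrative only):

    import torch
    import flashinfer

    hidden_size = 4096
    # The fused kernels take the two halves of the up/gate projection
    # concatenated on the last dimension: input is (..., 2 * hidden_size),
    # output is (..., hidden_size).
    x = torch.randn(8, 2 * hidden_size, dtype=torch.float16, device="cuda")

    y_silu = flashinfer.activation.silu_and_mul(x)            # silu(x[..., :h]) * x[..., h:]
    y_gelu = flashinfer.activation.gelu_and_mul(x)            # gelu(x[..., :h]) * x[..., h:]
    y_gelu_tanh = flashinfer.activation.gelu_tanh_and_mul(x)  # tanh-approximate GeLU variant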
3 changes: 3 additions & 0 deletions docs/api/python/norm.rst
@@ -11,3 +11,6 @@ Kernels for normalization layers.
:toctree: _generate

rmsnorm
fused_add_rmsnorm
gemma_rmsnorm
gemma_fused_add_rmsnorm
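
A usage sketch for the RMSNorm family (argument order for the fused variant follows the ``gemma_fused_add_rmsnorm`` signature shown later in this diff; treat it as illustrative, not canonical):

    import torch
    import flashinfer

    hidden_size = 4096
    x = torch.randn(8, hidden_size, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")

    # Out-of-place: out[i] = (x[i] / RMS(x)) * weight[i]
    y = flashinfer.norm.rmsnorm(x, weight, eps=1e-6)

    # Fused residual add + norm, in place:
    #   residual += x;  x = (residual / RMS(residual)) * weight
    flashinfer.norm.fused_add_rmsnorm(x, residual, weight, eps=1e-6)

    # Gemma variants scale by (weight + 1) instead of weight.
    y_gemma = flashinfer.norm.gemma_rmsnorm(x, weight, eps=1e-6)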
1 change: 1 addition & 0 deletions docs/api/python/page.rst
@@ -14,3 +14,4 @@ Append new K/V tensors to Paged KV-Cache
:toctree: ../../generated

append_paged_kv_cache
get_batch_indices_positions
6 changes: 6 additions & 0 deletions docs/api/python/rope.rst
@@ -14,3 +14,9 @@ Kernels for applying rotary embeddings.
apply_llama31_rope_inplace
apply_rope
apply_llama31_rope
apply_rope_pos_ids
apply_rope_pos_ids_inplace
apply_llama31_rope_pos_ids
apply_llama31_rope_pos_ids_inplace
apply_rope_with_cos_sin_cache
apply_rope_with_cos_sin_cache_inplace
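
The newly indexed ``*_pos_ids`` variants take explicit position ids instead of an indptr layout. As a rough, pure-PyTorch illustration of what they compute (non-interleaved rotation only; the actual kernels additionally handle interleaved layouts, partial rotary dimensions, and Llama-3.1-style scaling, and their exact signatures may differ):

    import torch

    def reference_rope_pos_ids(x: torch.Tensor, pos_ids: torch.Tensor,
                               rope_theta: float = 1e4) -> torch.Tensor:
        # x: [nnz, num_heads, head_dim], pos_ids: [nnz]
        half = x.size(-1) // 2
        freqs = 1.0 / (rope_theta ** (torch.arange(half, device=x.device, dtype=torch.float32) / half))
        angles = pos_ids.to(torch.float32)[:, None] * freqs[None, :]   # [nnz, half]
        cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
        x1, x2 = x[..., :half].float(), x[..., half:].float()
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).to(x.dtype)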
1 change: 1 addition & 0 deletions docs/index.rst
@@ -36,4 +36,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/gemm
api/python/norm
api/python/rope
api/python/activation
api/python/quantization
6 changes: 6 additions & 0 deletions python/flashinfer/activation.py
@@ -111,6 +111,8 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused SiLU and Mul operation.
``silu(input[..., :hidden_size]) * input[..., hidden_size:]``
Parameters
----------
input: torch.Tensor
@@ -141,6 +143,8 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU Tanh and Mul operation.
``gelu_tanh(input[..., :hidden_size]) * input[..., hidden_size:]``, where ``gelu_tanh`` is GeLU with the tanh approximation
Parameters
----------
input: torch.Tensor
@@ -171,6 +175,8 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU and Mul operation.
``gelu(input[..., :hidden_size]) * input[..., hidden_size:]``
Parameters
----------
input: torch.Tensor
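
For reference, eager PyTorch equivalents of the three formulas documented above (reading ``gelu_tanh`` as GeLU with the tanh approximation is an assumption about the kernel; these sketches are for correctness checking, not performance):

    import torch
    import torch.nn.functional as F

    def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        h = x.size(-1) // 2
        return F.silu(x[..., :h]) * x[..., h:]

    def ref_gelu_tanh_and_mul(x: torch.Tensor) -> torch.Tensor:
        h = x.size(-1) // 2
        return F.gelu(x[..., :h], approximate="tanh") * x[..., h:]

    def ref_gelu_and_mul(x: torch.Tensor) -> torch.Tensor:
        h = x.size(-1) // 2
        return F.gelu(x[..., :h]) * x[..., h:]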
20 changes: 18 additions & 2 deletions python/flashinfer/norm.py
@@ -50,6 +50,8 @@ def rmsnorm(
) -> torch.Tensor:
r"""Root mean square normalization.
``out[i] = (input[i] / RMS(input)) * weight[i]``
Parameters
----------
input: torch.Tensor
@@ -92,6 +94,12 @@ def fused_add_rmsnorm(
) -> None:
r"""Fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
Parameters
----------
input: torch.Tensor
@@ -119,7 +127,9 @@ def gemma_rmsnorm(
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Gemma Root mean square normalization.
r"""Gemma-style root mean square normalization.
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
Parameters
----------
@@ -163,7 +173,13 @@ def _gemma_rmsnorm_fake(
def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
r"""Gemma Fused add root mean square normalization.
r"""Gemma-style fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``
Parameters
----------
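
Eager PyTorch sketches of the formulas above (whether ``eps`` sits inside the square root and whether the kernel accumulates in fp32 are assumptions; the docstring formulas define the intended math):

    import torch

    def ref_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        return ((x.float() / rms) * weight.float()).to(x.dtype)

    def ref_gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
        return ((x.float() / rms) * (weight.float() + 1.0)).to(x.dtype)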
12 changes: 10 additions & 2 deletions python/flashinfer/page.py
@@ -151,11 +151,15 @@ def get_batch_indices_positions(
>>> positions # the rightmost column index of each row
tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)
Notes
-----
Note
----
This function is similar to the `CSR2COO <https://docs.nvidia.com/cuda/cusparse/#csr2coo>`_
conversion in the cuSPARSE library, with the difference that we are converting from a ragged
tensor (which doesn't require a column indices array) to COO format.
See Also
--------
append_paged_kv_cache
"""
batch_size = append_indptr.size(0) - 1
batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
@@ -305,6 +309,10 @@ def append_paged_kv_cache(
The function assumes that the space for the appended k/v has already been allocated,
which means :attr:`kv_indices`, :attr:`kv_indptr`, and :attr:`kv_last_page_len` have
already incorporated the appended k/v.
See Also
--------
get_batch_indices_positions
"""
_check_kv_layout(kv_layout)
_append_paged_kv_cache_kernel(
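
A usage sketch tying the two functions together, reproducing the docstring example above (the package-level export and the ``(append_indptr, seq_lens, nnz)`` argument order are assumptions based on that example):

    import torch
    import flashinfer

    # Four requests appending 1, 2, 3 and 4 new tokens (nnz = 10 new K/V entries),
    # each reaching a total sequence length of 5 after the append.
    nnz = 10
    append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32, device="cuda")
    seq_lens = torch.tensor([5, 5, 5, 5], dtype=torch.int32, device="cuda")

    # Expand the ragged indptr into per-entry (batch index, position) pairs,
    # analogous to a cuSPARSE CSR2COO conversion on the row pointer.
    batch_indices, positions = flashinfer.get_batch_indices_positions(
        append_indptr, seq_lens, nnz
    )
    # positions == [4, 3, 4, 2, 3, 4, 1, 2, 3, 4]; these two tensors are then
    # passed to append_paged_kv_cache along with the paged KV-cache tensors and
    # kv_indices / kv_indptr / kv_last_page_len.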
