add xpu support #396

Open · wants to merge 8 commits into base: main
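Every benchmark script touched by this PR replaces the hardcoded `"cuda"` device with a module-level `device = infer_device()`, imported from `liger_kernel.utils`. The helper itself is not shown in this diff, so the snippet below is only a minimal sketch of what it presumably does, assuming it returns the first available backend in the order CUDA, XPU, CPU; the function name and import path come from the diff, the body is an assumption.

```python
# Hypothetical sketch of infer_device() from liger_kernel/utils.py (not shown in this diff).
# Assumption: it returns the first available accelerator backend as a device string.
import torch


def infer_device() -> str:
    """Return "cuda", "xpu", or "cpu", whichever is available first."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu is the Intel GPU backend; guard with hasattr for builds without it.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"
```

With that module-level call in place, every `.to(device)` and `device=device` site below targets whichever backend is present, so the benchmarks run unchanged on NVIDIA or Intel GPUs.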
7 changes: 4 additions & 3 deletions benchmark/scripts/benchmark_cpo_loss.py
@@ -13,6 +13,10 @@
)

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

@@ -66,7 +70,6 @@ def bench_memory_fused_linear_cpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

-device = "cuda"
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

@@ -107,8 +110,6 @@ def bench_speed_fused_linear_cpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

-device = "cuda"
-
torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)

12 changes: 8 additions & 4 deletions benchmark/scripts/benchmark_cross_entropy.py
@@ -11,6 +11,10 @@
)

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


def bench_memory_cross_entropy(
@@ -24,8 +28,8 @@ def bench_memory_cross_entropy(
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda")
-target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
+_input = torch.randn(B * T, V, requires_grad=True, device=device)
+target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

def fwd():
if provider == "liger":
@@ -57,8 +61,8 @@ def bench_speed_cross_entropy(
B = input.extra_benchmark_config["B"]
T = input.extra_benchmark_config["T"]

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda")
-target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
+_input = torch.randn(B * T, V, requires_grad=True, device=device)
+target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

def fwd():
if provider == "liger":
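For a concrete picture of the pattern, here is a short standalone sketch (not part of the PR) that builds inputs the same way benchmark_cross_entropy.py does and runs a forward/backward pass through LigerCrossEntropyLoss on whichever device infer_device() returns; the sizes are illustrative.

```python
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()

B, T, V = 4, 512, 32000  # illustrative batch size, sequence length, vocab size
liger_ce = LigerCrossEntropyLoss()

# Same construction as the benchmark: flattened logits and integer targets on the inferred device.
_input = torch.randn(B * T, V, requires_grad=True, device=device)
target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

loss = liger_ce(_input, target)
loss.backward()
```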
6 changes: 4 additions & 2 deletions benchmark/scripts/benchmark_dpo_loss.py
@@ -12,6 +12,10 @@
)

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


class TorchDPOLoss(torch.nn.Module):
@@ -79,7 +83,6 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

-device = "cuda"
torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
@@ -127,7 +130,6 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
provider = input.kernel_provider
mode = input.kernel_operation_mode

-device = "cuda"
torch_dpo_loss = TorchDPOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
8 changes: 4 additions & 4 deletions benchmark/scripts/benchmark_embedding.py
@@ -11,6 +11,10 @@
)

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()

# NOTE: For torch compile, we will just use default inductor settings. No further customization
# is needed.
@@ -26,8 +30,6 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]

-device = "cuda"
-
torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
@@ -68,8 +70,6 @@ def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
D = input.extra_benchmark_config["D"]
dtype = input.extra_benchmark_config["dtype"]

-device = "cuda"
-
torch_emb = Embedding(V, D).to(device).to(dtype)
liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
torch_compile_emb = torch.compile(torch_emb)
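The embedding benchmark compares LigerEmbedding against an eager and a torch.compile'd nn.Embedding. A reduced sketch of that comparison under the same device-agnostic setup is below; the sizes are illustrative, and torch.compile support on non-CUDA backends is assumed rather than verified here.

```python
import torch
from torch.nn import Embedding

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device

device = infer_device()
V, D = 32000, 1024  # illustrative vocab and embedding sizes

torch_emb = Embedding(V, D).to(device)
liger_emb = LigerEmbedding(V, D).to(device)
torch_compile_emb = torch.compile(torch_emb)  # default inductor settings, as in the script

input_ids = torch.randint(0, V, (8, 512), device=device)
out_eager = torch_emb(input_ids)
out_liger = liger_emb(input_ids)
out_compiled = torch_compile_emb(input_ids)
```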
7 changes: 4 additions & 3 deletions benchmark/scripts/benchmark_fused_linear_cross_entropy.py
@@ -12,6 +12,10 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


class TorchLMHeadCE(torch.nn.Module):
@@ -65,7 +69,6 @@ def bench_memory_fused_linear_cross_entropy(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

-device = "cuda"
torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

@@ -105,8 +108,6 @@ def bench_speed_fused_linear_cross_entropy(
provider = input.kernel_provider
mode = input.kernel_operation_mode

-device = "cuda"
-
torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

6 changes: 4 additions & 2 deletions benchmark/scripts/benchmark_fused_linear_jsd.py
@@ -10,6 +10,10 @@
)

from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


class TorchJSD(torch.nn.Module):
@@ -134,7 +138,6 @@ def bench_memory_fused_linear_jsd(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

-device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

@@ -183,7 +186,6 @@ def bench_speed_fused_linear_jsd(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

-device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

6 changes: 4 additions & 2 deletions benchmark/scripts/benchmark_geglu.py
@@ -12,6 +12,10 @@
)

from liger_kernel.transformers.geglu import LigerGEGLUMLP
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
@@ -31,7 +35,6 @@ def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
)

x_shape = (bsz, seq_len, hidden_size)
-device = "cuda"

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
@@ -99,7 +102,6 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp
)

x_shape = (bsz, seq_len, hidden_size)
-device = "cuda"
# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

16 changes: 10 additions & 6 deletions benchmark/scripts/benchmark_group_norm.py
@@ -10,6 +10,10 @@
)

from liger_kernel.transformers.group_norm import LigerGroupNorm
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
@@ -26,12 +30,12 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
x_shape = (M, C, H)
triton_ln = LigerGroupNorm(
num_channels=C, num_groups=C // channels_per_group, eps=eps
-).to("cuda")
+).to(device)
torch_ln = torch.nn.GroupNorm(
num_groups=C // channels_per_group, num_channels=C, eps=eps
-).to("cuda")
+).to(device)

-x = torch.randn(x_shape, dtype=dtype, device="cuda")
+x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

@@ -83,12 +87,12 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
x_shape = (M, C, H)
triton_ln = LigerGroupNorm(
num_channels=C, num_groups=C // channels_per_group, eps=eps
-).to("cuda")
+).to(device)
torch_ln = torch.nn.GroupNorm(
num_groups=C // channels_per_group, num_channels=C, eps=eps
-).to("cuda")
+).to(device)

-x = torch.randn(x_shape, dtype=dtype, device="cuda")
+x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

12 changes: 8 additions & 4 deletions benchmark/scripts/benchmark_jsd.py
@@ -10,6 +10,10 @@
)

from liger_kernel.transformers.jsd import LigerJSD
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


class TorchJSD(torch.nn.Module):
@@ -56,10 +60,10 @@ def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
+_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
-target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
+target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
@@ -101,10 +105,10 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
+_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
-target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
+target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
12 changes: 8 additions & 4 deletions benchmark/scripts/benchmark_kl_div.py
@@ -11,6 +11,10 @@
)

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()

S, E = 12, 18

@@ -22,10 +26,10 @@ def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
torch_kl_div = nn.KLDivLoss(reduction=reduction)
liger_kl_div = LigerKLDIVLoss(reduction=reduction)

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
+_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
-target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+target = torch.randn(B * T, V, device=device).softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
@@ -68,10 +72,10 @@ def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

-_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
+_input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
dim=-1
)
-target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+target = torch.randn(B * T, V, device=device).softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
16 changes: 10 additions & 6 deletions benchmark/scripts/benchmark_layer_norm.py
@@ -10,6 +10,10 @@
)

from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()


def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
@@ -22,10 +26,10 @@ def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun
dtype = extra_benchmark_config["dtype"]

x_shape = (M, N)
-triton_ln = LigerLayerNorm(hidden_size=N).to("cuda")
-torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda")
+triton_ln = LigerLayerNorm(hidden_size=N).to(device)
+torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

-x = torch.randn(x_shape, dtype=dtype, device="cuda")
+x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

@@ -73,10 +77,10 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu

x_shape = (M, N)

-triton_ln = LigerLayerNorm(hidden_size=N).to("cuda")
-torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda")
+triton_ln = LigerLayerNorm(hidden_size=N).to(device)
+torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

-x = torch.randn(x_shape, dtype=dtype, device="cuda")
+x = torch.randn(x_shape, dtype=dtype, device=device)
dy = torch.randn_like(x)
x.requires_grad_(True)

7 changes: 4 additions & 3 deletions benchmark/scripts/benchmark_orpo_loss.py
@@ -13,6 +13,10 @@
)

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+from liger_kernel.utils import infer_device
+
+
+device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

@@ -66,7 +70,6 @@ def bench_memory_fused_linear_orpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

-device = "cuda"
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

@@ -107,8 +110,6 @@ def bench_speed_fused_linear_orpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

-device = "cuda"
-
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
