From 4ced0426f386c533e9a5f510cc626706a1b89c46 Mon Sep 17 00:00:00 2001 From: Golam Rabbani Date: Tue, 19 Nov 2024 08:28:09 -0800 Subject: [PATCH 1/5] add xpu support --- benchmark/scripts/benchmark_cross_entropy.py | 12 ++++-- benchmark/scripts/benchmark_dpo_loss.py | 6 ++- benchmark/scripts/benchmark_embedding.py | 8 ++-- .../benchmark_fused_linear_cross_entropy.py | 7 ++-- .../scripts/benchmark_fused_linear_jsd.py | 6 ++- benchmark/scripts/benchmark_geglu.py | 6 ++- benchmark/scripts/benchmark_group_norm.py | 16 +++++--- benchmark/scripts/benchmark_jsd.py | 12 ++++-- benchmark/scripts/benchmark_kl_div.py | 12 ++++-- benchmark/scripts/benchmark_layer_norm.py | 16 +++++--- benchmark/scripts/benchmark_orpo_loss.py | 7 ++-- benchmark/scripts/benchmark_rms_norm.py | 16 +++++--- benchmark/scripts/benchmark_rope.py | 28 +++++++------ benchmark/scripts/benchmark_swiglu.py | 6 ++- benchmark/scripts/utils.py | 12 ++++-- examples/huggingface/callback.py | 8 ++-- examples/lightning/training.py | 6 ++- examples/medusa/callback.py | 10 +++-- src/liger_kernel/__init__.py | 0 src/liger_kernel/ops/layer_norm.py | 7 +++- src/liger_kernel/ops/rms_norm.py | 1 + src/liger_kernel/ops/utils.py | 6 ++- src/liger_kernel/utils.py | 13 +++++++ test/chunked_loss/test_dpo_loss.py | 11 ++++-- test/chunked_loss/test_orpo_loss.py | 11 ++++-- test/convergence/test_mini_models.py | 6 ++- .../test_mini_models_multimodal.py | 6 ++- .../test_mini_models_with_logits.py | 6 ++- test/transformers/test_cross_entropy.py | 39 ++++++++++--------- test/transformers/test_embedding.py | 5 ++- .../test_fused_linear_cross_entropy.py | 11 +++--- test/transformers/test_fused_linear_jsd.py | 11 ++---- test/transformers/test_geglu.py | 19 +++++---- test/transformers/test_group_norm.py | 9 +++-- test/transformers/test_jsd.py | 13 ++++--- test/transformers/test_kl_div.py | 5 ++- test/transformers/test_layer_norm.py | 15 ++++--- test/transformers/test_mm_int8int2.py | 7 +++- test/transformers/test_rms_norm.py | 5 ++- test/transformers/test_rope.py | 23 ++++++----- test/transformers/test_swiglu.py | 29 +++++++------- test/utils.py | 26 ++++--------- 42 files changed, 291 insertions(+), 187 deletions(-) create mode 100644 src/liger_kernel/__init__.py create mode 100644 src/liger_kernel/utils.py diff --git a/benchmark/scripts/benchmark_cross_entropy.py b/benchmark/scripts/benchmark_cross_entropy.py index d6dffbf7e..4409eb616 100644 --- a/benchmark/scripts/benchmark_cross_entropy.py +++ b/benchmark/scripts/benchmark_cross_entropy.py @@ -11,6 +11,10 @@ ) from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_memory_cross_entropy( @@ -24,8 +28,8 @@ def bench_memory_cross_entropy( B = input.extra_benchmark_config["B"] T = input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda") - target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1) + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) def fwd(): if provider == "liger": @@ -57,8 +61,8 @@ def bench_speed_cross_entropy( B = input.extra_benchmark_config["B"] T = input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda") - target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1) + _input = torch.randn(B * T, V, requires_grad=True, device=device) + target = torch.randint(V, (B * T, 1), device=device).squeeze(1) def fwd(): 
if provider == "liger": diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 537be47bc..5ab1c3f44 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -12,6 +12,10 @@ ) from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() class TorchDPOLoss(torch.nn.Module): @@ -79,7 +83,6 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO ignore_index = input.extra_benchmark_config["ignore_index"] provider = input.kernel_provider - device = "cuda" torch_dpo_loss = TorchDPOLoss( H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) @@ -127,7 +130,6 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" torch_dpo_loss = TorchDPOLoss( H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias ).to(device) diff --git a/benchmark/scripts/benchmark_embedding.py b/benchmark/scripts/benchmark_embedding.py index 1f20aec35..aa49ef3c7 100644 --- a/benchmark/scripts/benchmark_embedding.py +++ b/benchmark/scripts/benchmark_embedding.py @@ -11,6 +11,10 @@ ) from liger_kernel.transformers.experimental.embedding import LigerEmbedding +from liger_kernel.utils import infer_device + + +device = infer_device() # NOTE: For torch compile, we will just use default inductor settings. No further customization # is needed. @@ -26,8 +30,6 @@ def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO D = input.extra_benchmark_config["D"] dtype = input.extra_benchmark_config["dtype"] - device = "cuda" - torch_emb = Embedding(V, D).to(device).to(dtype) liger_emb = LigerEmbedding(V, D).to(device).to(dtype) torch_compile_emb = torch.compile(torch_emb) @@ -68,8 +70,6 @@ def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun D = input.extra_benchmark_config["D"] dtype = input.extra_benchmark_config["dtype"] - device = "cuda" - torch_emb = Embedding(V, D).to(device).to(dtype) liger_emb = LigerEmbedding(V, D).to(device).to(dtype) torch_compile_emb = torch.compile(torch_emb) diff --git a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py index eaceeed03..b7f6721ea 100644 --- a/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +++ b/benchmark/scripts/benchmark_fused_linear_cross_entropy.py @@ -12,6 +12,10 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.utils import infer_device + + +device = infer_device() class TorchLMHeadCE(torch.nn.Module): @@ -65,7 +69,6 @@ def bench_memory_fused_linear_cross_entropy( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) @@ -105,8 +108,6 @@ def bench_speed_fused_linear_cross_entropy( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device) liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index 7f652de8a..d112529ee 100644 --- 
a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -10,6 +10,10 @@ ) from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD +from liger_kernel.utils import infer_device + + +device = infer_device() class TorchJSD(torch.nn.Module): @@ -134,7 +138,6 @@ def bench_memory_fused_linear_jsd( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) @@ -183,7 +186,6 @@ def bench_speed_fused_linear_jsd( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device) diff --git a/benchmark/scripts/benchmark_geglu.py b/benchmark/scripts/benchmark_geglu.py index 81611de3f..19ce166f1 100644 --- a/benchmark/scripts/benchmark_geglu.py +++ b/benchmark/scripts/benchmark_geglu.py @@ -12,6 +12,10 @@ ) from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -31,7 +35,6 @@ def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) @@ -99,7 +102,6 @@ def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 595d379f8..f3562e025 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -10,6 +10,10 @@ ) from liger_kernel.transformers.group_norm import LigerGroupNorm +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -26,12 +30,12 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun x_shape = (M, C, H) triton_ln = LigerGroupNorm( num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to("cuda") + ).to(device) torch_ln = torch.nn.GroupNorm( num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to("cuda") + ).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -83,12 +87,12 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu x_shape = (M, C, H) triton_ln = LigerGroupNorm( num_channels=C, num_groups=C // channels_per_group, eps=eps - ).to("cuda") + ).to(device) torch_ln = torch.nn.GroupNorm( num_groups=C // channels_per_group, num_channels=C, eps=eps - ).to("cuda") + ).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py index 272008315..bff8398ab 100644 --- 
a/benchmark/scripts/benchmark_jsd.py +++ b/benchmark/scripts/benchmark_jsd.py @@ -10,6 +10,10 @@ ) from liger_kernel.transformers.jsd import LigerJSD +from liger_kernel.utils import infer_device + + +device = infer_device() class TorchJSD(torch.nn.Module): @@ -56,10 +60,10 @@ def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: torch_jsd = TorchJSD() liger_jsd = LigerJSD() - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": @@ -101,10 +105,10 @@ def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1) + target = torch.randn(B * T, V, device=device).log_softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": diff --git a/benchmark/scripts/benchmark_kl_div.py b/benchmark/scripts/benchmark_kl_div.py index c446c7ae2..6062627c0 100644 --- a/benchmark/scripts/benchmark_kl_div.py +++ b/benchmark/scripts/benchmark_kl_div.py @@ -11,6 +11,10 @@ ) from liger_kernel.transformers.kl_div import LigerKLDIVLoss +from liger_kernel.utils import infer_device + + +device = infer_device() S, E = 12, 18 @@ -22,10 +26,10 @@ def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu torch_kl_div = nn.KLDivLoss(reduction=reduction) liger_kl_div = LigerKLDIVLoss(reduction=reduction) - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": @@ -68,10 +72,10 @@ def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp V = input.x B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"] - _input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax( + _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax( dim=-1 ) - target = torch.randn(B * T, V, device="cuda").softmax(dim=-1) + target = torch.randn(B * T, V, device=device).softmax(dim=-1) def fwd(): if input.kernel_provider == "liger": diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 89f07c640..a6b46c3f7 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -10,6 +10,10 @@ ) from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -22,10 +26,10 @@ def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun dtype = extra_benchmark_config["dtype"] x_shape = (M, N) - triton_ln = LigerLayerNorm(hidden_size=N).to("cuda") - torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda") + triton_ln = 
LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -73,10 +77,10 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu x_shape = (M, N) - triton_ln = LigerLayerNorm(hidden_size=N).to("cuda") - torch_ln = torch.nn.LayerNorm(N, eps=eps).to("cuda") + triton_ln = LigerLayerNorm(hidden_size=N).to(device) + torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index dda42d772..0669370b5 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -13,6 +13,10 @@ ) from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +70,6 @@ def bench_memory_fused_linear_orpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +110,6 @@ def bench_speed_fused_linear_orpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_rms_norm.py b/benchmark/scripts/benchmark_rms_norm.py index 46734504e..64fa7072a 100644 --- a/benchmark/scripts/benchmark_rms_norm.py +++ b/benchmark/scripts/benchmark_rms_norm.py @@ -11,6 +11,10 @@ ) from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + + +device = infer_device() class LlamaRMSNorm(nn.Module): @@ -42,10 +46,10 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu x_shape = (M, N) - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to("cuda") - llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to("cuda") + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) @@ -104,10 +108,10 @@ def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO x_shape = (M, N) - triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to("cuda") - llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to("cuda") + triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device) + llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device) - x = torch.randn(x_shape, dtype=dtype, device="cuda") + x = torch.randn(x_shape, dtype=dtype, device=device) dy = torch.randn_like(x) x.requires_grad_(True) diff --git a/benchmark/scripts/benchmark_rope.py b/benchmark/scripts/benchmark_rope.py index 265fe703a..5433cb928 100644 --- a/benchmark/scripts/benchmark_rope.py +++ b/benchmark/scripts/benchmark_rope.py @@ -14,6 +14,10 @@ ) from liger_kernel.transformers.rope import 
liger_rotary_pos_emb +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -38,23 +42,23 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput ) head_dim = hidden_size // num_q_heads - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) def fwd(): @@ -122,23 +126,23 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu ) head_dim = hidden_size // num_q_heads - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k, pos_ids) def full(): diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index 08689d24e..07332c83b 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -12,6 +12,10 @@ ) from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: @@ -33,7 +37,6 @@ def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutp ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) @@ -103,7 +106,6 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut ) x_shape = (bsz, seq_len, hidden_size) - device = "cuda" # initialize input x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True) diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 1d147b51b..ddd4d05d0 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -10,6 +10,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch +from liger_kernel.utils import infer_device + +device = infer_device() LIGER_KERNEL_VERSION = version("liger-kernel") @@ -88,10 +91,10 @@ def _test_memory( total_mem = [] for _ in range(_iter): - torch.cuda.memory.reset_peak_memory_stats() + getattr(torch, 
device).memory.reset_peak_memory_stats() func() # Convert to MB - mem = torch.cuda.max_memory_allocated() / 2**20 + mem = getattr(torch, device).max_memory_allocated() / 2**20 total_mem.append(mem) total_mem = torch.tensor(total_mem, dtype=torch.float) @@ -141,8 +144,9 @@ def get_gpu_name(): """ Returns the current GPU name, formatted to serve as a directory name """ - if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()) + torch_device = getattr(torch, device) + if torch_device.is_available(): + gpu_name = torch_device.get_device_name(torch_device.current_device()) return gpu_name else: raise Exception("Benchmarks can only be run on GPU.") diff --git a/examples/huggingface/callback.py b/examples/huggingface/callback.py index 9582c81fd..6bf5412d6 100644 --- a/examples/huggingface/callback.py +++ b/examples/huggingface/callback.py @@ -4,6 +4,7 @@ import torch import transformers from transformers import TrainerControl, TrainerState, TrainingArguments +from liger_kernel.utils import infer_device # https://simple.wikipedia.org/wiki/Byte # For memory, we use binary system @@ -111,6 +112,7 @@ def __init__( self.time = Time() self.memory = Memory() self.tps = TPS() + self.device = infer_device() def on_init_end( self, @@ -171,7 +173,7 @@ def on_step_begin( several inputs. """ # memory - torch.cuda.reset_peak_memory_stats() + getattr(torch, self.device).reset_peak_memory_stats() # time self.state.step_start_time = time.perf_counter() @@ -218,8 +220,8 @@ def on_step_end( ) # memory - step_peak_memory_allocated = torch.cuda.memory.max_memory_allocated() - step_peak_memory_reserved = torch.cuda.memory.max_memory_reserved() + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory diff --git a/examples/lightning/training.py b/examples/lightning/training.py index f70e9aac1..c46f6cbd7 100644 --- a/examples/lightning/training.py +++ b/examples/lightning/training.py @@ -15,6 +15,7 @@ from trl import DataCollatorForCompletionOnlyLM from liger_kernel.transformers import AutoLigerKernelForCausalLM +from liger_kernel.utils import infer_device _RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"} QUESTION = "" @@ -263,10 +264,11 @@ def train(): strategy = "auto" precision = "bf16-true" + device = infer_device() trainer = pl.Trainer( - accelerator="cuda", + accelerator=device, strategy=strategy, - devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu, + devices=getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu, default_root_dir=args.output_dir, log_every_n_steps=1, max_epochs=1, diff --git a/examples/medusa/callback.py b/examples/medusa/callback.py index ef4c38f1e..fb35c1cc3 100644 --- a/examples/medusa/callback.py +++ b/examples/medusa/callback.py @@ -6,6 +6,7 @@ import transformers from accelerate.utils.constants import FSDP_SHARDING_STRATEGY from transformers import TrainerControl, TrainerState, TrainingArguments +from liger_kernel.utils import infer_device # https://simple.wikipedia.org/wiki/Byte # For memory, we use binary system @@ -137,6 +138,7 @@ def __init__( self.memory = Memory() self.tps = TPS() self.mfu = MFU() + self.device = infer_device() def on_init_end( self, @@ -198,7 +200,7 @@ def on_step_begin( several inputs. 
""" # memory - torch.cuda.reset_peak_memory_stats() + getattr(torch, self.device).reset_peak_memory_stats() # time self.state.step_start_time = time.perf_counter() @@ -247,8 +249,8 @@ def on_step_end( ) # memory - step_peak_memory_allocated = torch.cuda.memory.max_memory_allocated() - step_peak_memory_reserved = torch.cuda.memory.max_memory_reserved() + step_peak_memory_allocated = getattr(torch, self.device).memory.max_memory_allocated() + step_peak_memory_reserved = getattr(torch, self.device).memory.max_memory_reserved() self.memory.step_peak_memory_allocated_MB = round_to_n_decimal( step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory @@ -381,7 +383,7 @@ def _get_gpu_peak_tflops(precision_bits: int = 16): if precision_bits not in {16, 32}: raise Exception(f"Precision bits {precision_bits} is not supported") - device_name = torch.cuda.get_device_name() + device_name = getattr(torch, infer_device()).get_device_name() if "A100" in device_name: # data from https://www.nvidia.com/en-us/data-center/a100/ diff --git a/src/liger_kernel/__init__.py b/src/liger_kernel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index 75df1f6ba..70c372237 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD): dY = dY.view(-1, dim) n_rows, n_cols = dY.shape + sm_count = 1 + if X.device.type == "cuda": + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + elif X.device.type == "xpu": + sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index 633a3275b..fff199a93 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -264,6 +264,7 @@ def rms_norm_backward( dY = dY.view(-1, dim) n_rows, n_cols = dY.shape + sm_count = 1 if X.device.type == "cuda": sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count elif X.device.type == "xpu": diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py index 4a24223d0..03b721d26 100644 --- a/src/liger_kernel/ops/utils.py +++ b/src/liger_kernel/ops/utils.py @@ -19,6 +19,7 @@ import triton import triton.language as tl from packaging.version import Version +from liger_kernel.utils import infer_device def is_hip() -> bool: @@ -69,10 +70,11 @@ def compare_version(package: str, operator: Callable, target: str): def get_amp_custom_fwd_bwd() -> Callable: + device = infer_device() if compare_version("torch", operator.ge, "2.4.0"): return ( - functools.partial(torch.amp.custom_fwd, device_type="cuda"), - functools.partial(torch.amp.custom_bwd, device_type="cuda"), + functools.partial(torch.amp.custom_fwd, device_type=device), + functools.partial(torch.amp.custom_bwd, device_type=device), ) return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd diff --git a/src/liger_kernel/utils.py b/src/liger_kernel/utils.py new file mode 100644 index 000000000..0a6d5feba --- /dev/null +++ b/src/liger_kernel/utils.py @@ -0,0 +1,13 @@ +import torch + + +def infer_device(): + """ + Get current device name 
based on available devices + """ + if torch.cuda.is_available(): + return "cuda" + elif torch.xpu.is_available(): + return "xpu" + else: + return "cpu" diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index 7f4eef053..bd7509734 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -5,6 +5,9 @@ import torch.nn.functional as F from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -58,7 +61,7 @@ def alignment_loss( def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): B = 2 * B # dpo loss requires B to be even - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -69,7 +72,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -77,11 +80,11 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 5e532938b..4f13c3715 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -6,6 +6,9 @@ import torch.nn.functional as F from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -76,7 +79,7 @@ def alignment_loss( def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): B = 2 * B # orpo loss requires B to be even - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -87,7 +90,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -95,11 +98,11 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = 
torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 5c30349ae..051effcfa 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -60,6 +60,10 @@ except ImportError: QWEN2_VL_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_llama, @@ -427,7 +431,7 @@ def run_mini_model( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader( diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index bb9d8e712..07ddd9493 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -58,6 +58,10 @@ except ImportError: MLLAMA_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + torch.use_deterministic_algorithms(True) # Only setting torch.use_deterministic_algorithms(True) throws the following error: @@ -333,7 +337,7 @@ def run_mini_model_multimodal( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) model.gradient_checkpointing_enable() train_dataset = create_multimodal_dataset(model_name) diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index 0b183e3d3..e7672c4a4 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -60,6 +60,10 @@ except ImportError: QWEN2_VL_AVAILABLE = False +from liger_kernel.utils import infer_device + +device = infer_device() + MINI_MODEL_SETUPS = { "mini_llama3": MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_llama, @@ -427,7 +431,7 @@ def run_mini_model( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to("cuda") + model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader( train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 6ec73a1a3..10114ea06 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -8,7 +8,10 @@ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) @@ -71,11 +74,11 @@ def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, r torch.manual_seed(0) torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * 
scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -92,11 +95,11 @@ def _test_correctness_with_ignore_index_once( torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -123,11 +126,11 @@ def _test_correctness_with_label_smoothing_once( torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -147,11 +150,11 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( ignore_index=ignore_index, label_smoothing=label_smoothing ) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -178,12 +181,12 @@ def _test_correctness_with_softcap_once( torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar # upcasting to match liger's casting strategy _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # downcasting to original dtype output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) @@ -214,11 +217,11 @@ def _test_correctness_with_z_loss_once( dtype=dtype, ) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) if return_z_loss: output, z_output = torch_ce(_input, target) output2, z_output2 = target_ce(_input2, target) @@ -263,11 +266,11 @@ def 
_test_correctness_with_z_loss_with_other_params_once( dtype=dtype, ) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint( @@ -302,11 +305,11 @@ def _test_correctness_not_last_layer_once( torch_ce = CrossEntropyLoss(reduction=reduction) - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) output = torch_ce(_input, target) output2 = target_ce(_input2, target) @@ -330,12 +333,12 @@ def _test_correctness_functional( rtol, ): - _input = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B * T, V, device=device, dtype=dtype) * scalar x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) y1, y1_z = liger_cross_entropy( x1, diff --git a/test/transformers/test_embedding.py b/test/transformers/test_embedding.py index 998a544c5..416784d0f 100644 --- a/test/transformers/test_embedding.py +++ b/test/transformers/test_embedding.py @@ -3,6 +3,9 @@ from torch.nn import Embedding from liger_kernel.transformers.experimental.embedding import LigerEmbedding +from liger_kernel.utils import infer_device + +device = infer_device() SLEEP_SECONDS = 0.1 @@ -27,7 +30,7 @@ @pytest.mark.parametrize( "dtype, atol, rtol, device", [ - (torch.float32, 1e-6, 1e-5, "cuda"), + (torch.float32, 1e-6, 1e-5, device), ], ) def test_embedding_correctness( diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 881330c52..b69dcf492 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -12,6 +12,9 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.utils import infer_device + +device = infer_device() # set random seed globally set_seed() @@ -142,7 +145,6 @@ def test_correctness( atol, rtol, ): - device = "cuda" torch_lm_head_ce = TorchLMHeadCE( H=H, V=V, @@ -233,8 +235,6 @@ def test_correctness( ) @pytest.mark.parametrize("bias", [True, False]) def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): - device = "cuda" - _input = torch.randn(B * T, H, device=device, dtype=dtype) * scalar x1 = _input.detach().clone().requires_grad_(True) x2 = _input.detach().clone().requires_grad_(True) @@ -272,7 +272,6 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): ], ) def test_amp(B, T, H, V, cast_dtype, atol, rtol): - device = "cuda" dtype = torch.float32 torch_lm_head_ce = TorchLMHeadCE( H=H, @@ -302,13 +301,13 @@ def test_amp(B, T, H, V, cast_dtype, atol, rtol): target = torch.randint(0, V, (B * 
T,), device=device, dtype=torch.long) - with torch.autocast(device_type="cuda", dtype=cast_dtype): + with torch.autocast(device_type=device, dtype=cast_dtype): output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target) assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) - with torch.autocast(device_type="cuda", dtype=cast_dtype): + with torch.autocast(device_type=device, dtype=cast_dtype): output1.backward() output2.backward() diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py index 31a3ea103..06b7aab61 100644 --- a/test/transformers/test_fused_linear_jsd.py +++ b/test/transformers/test_fused_linear_jsd.py @@ -7,6 +7,9 @@ from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.transformers.functional import liger_fused_linear_jsd from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) @@ -108,7 +111,6 @@ def forward(self, student_input, teacher_input, label=None): ], ) def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -183,7 +185,6 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): def test_correctness_with_ignore_index( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -267,8 +268,6 @@ def test_correctness_with_ignore_index( def test_correctness_functional( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" - # init the linear in all FusedLinearJSDs with the same weights _weight = torch.rand(V, H // 2, device=device, dtype=dtype) _weight1 = _weight.detach().clone().requires_grad_(True) @@ -346,7 +345,6 @@ def test_correctness_functional( def test_correctness_all_ignored( B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): - device = "cuda" torch_lm_head_jsd = TorchLMHeadJSD( H=H, V=V, @@ -411,7 +409,6 @@ def test_amp(autocast_dtype, atol, rtol): ignore_index = -100 temperature = 1.0 beta = 0.5 - device = "cuda" dtype = torch.float32 torch_lm_head_jsd = TorchLMHeadJSD( H=H, @@ -456,7 +453,7 @@ def test_amp(autocast_dtype, atol, rtol): ] # Randomly select indices label[indices_to_assign] = ignore_index - with torch.autocast(device_type="cuda", dtype=autocast_dtype): + with torch.autocast(device_type=device, dtype=autocast_dtype): output1 = torch_lm_head_jsd(_input1, teacher_input, label) output2 = liger_lm_head_jsd(_input2, teacher_input, label) diff --git a/test/transformers/test_geglu.py b/test/transformers/test_geglu.py index cf7c5a3c5..5067dfec0 100644 --- a/test/transformers/test_geglu.py +++ b/test/transformers/test_geglu.py @@ -8,6 +8,9 @@ from liger_kernel.ops.geglu import LigerGELUMulFunction from liger_kernel.transformers.functional import liger_geglu from liger_kernel.transformers.geglu import LigerGEGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() LLAMA_CONFIG = LlamaConfig( hidden_size=4096, @@ -42,22 +45,22 @@ ], ) def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # 
initialize weights - G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) llama_mlp.gate_proj.weight.data = G.T llama_mlp.up_proj.weight.data = U.T llama_mlp.down_proj.weight.data = D.T - liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerGEGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) liger_mlp.gate_proj.weight.data = G.T liger_mlp.up_proj.weight.data = U.T liger_mlp.down_proj.weight.data = D.T @@ -121,8 +124,8 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, ], ) def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) - _b = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) + _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 32419ed6a..4f53444d5 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -4,6 +4,9 @@ import torch from liger_kernel.transformers.group_norm import LigerGroupNorm +from liger_kernel.utils import infer_device + +device = infer_device() random_batch_size = random.randint(1, 16) random_num_groups = random.randint(1, 32) @@ -32,17 +35,17 @@ def test_liger_group_norm( torch.manual_seed(0) _tensor = torch.randn( - batch_size, num_channels, hidden_size, dtype=dtype, device="cuda" + batch_size, num_channels, hidden_size, dtype=dtype, device=device ) liger_x = _tensor.clone().detach().requires_grad_(True) torch_x = _tensor.clone().detach().requires_grad_(True) - liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() + liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device) torch_ln = ( torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6) .to(dtype) - .cuda() + .to(device) ) with torch.no_grad(): diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py index 388b3a5c3..c2c116f97 100644 --- a/test/transformers/test_jsd.py +++ b/test/transformers/test_jsd.py @@ -7,6 +7,9 @@ from liger_kernel.transformers.functional import liger_jsd from liger_kernel.transformers.jsd import LigerJSD, LigerJSDFunction +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) @@ -84,7 +87,7 @@ def _test_correctness_once( atol, rtol, is_last_layer=True, - device="cuda", + device=device, ): torch_jsd = JSD(dtype=dtype) @@ -126,7 +129,7 @@ def _test_correctness_with_beta_once( atol, rtol, is_last_layer=True, - device="cuda", + device=device, ): torch_jsd = JSD(beta=beta, dtype=dtype) @@ -163,7 +166,7 @@ def _test_correctness_with_ignore_index_once( dtype, atol, rtol, - device="cuda", + device=device, ): torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) @@ -198,7 +201,7 @@ def 
_test_correctness_with_ignore_index_once( def _test_correctness_functional( - B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device="cuda" + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device=device ): input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True @@ -292,7 +295,7 @@ def test_correctness_with_all_indices_ignored( dtype=torch.bfloat16, atol=1e-3, rtol=1e-3, - device="cuda", + device=device, ): ignore_index = -100 torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) diff --git a/test/transformers/test_kl_div.py b/test/transformers/test_kl_div.py index 5cc3eba6a..1f0c2d5ad 100644 --- a/test/transformers/test_kl_div.py +++ b/test/transformers/test_kl_div.py @@ -5,6 +5,9 @@ from torch.nn import KLDivLoss from liger_kernel.transformers.kl_div import LigerKLDIVLoss +from liger_kernel.utils import infer_device + +device = infer_device() _SHAPE_PARAMS = ( "B, T, V", @@ -43,7 +46,7 @@ def _test_correctness_once( reduction, log_target, is_last_layer=True, - device="cuda", + device=device, ): torch.manual_seed(0) torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target) diff --git a/test/transformers/test_layer_norm.py b/test/transformers/test_layer_norm.py index e47d40999..1d4e773ee 100644 --- a/test/transformers/test_layer_norm.py +++ b/test/transformers/test_layer_norm.py @@ -4,6 +4,9 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction from liger_kernel.transformers.functional import liger_layer_norm from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.utils import infer_device + +device = infer_device() @pytest.mark.parametrize( @@ -22,13 +25,13 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol): torch.manual_seed(0) - x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) liger_x = x.clone().requires_grad_(True) torch_x = x.clone().requires_grad_(True) - liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() - torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).cuda() + liger_ln = LigerLayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) + torch_ln = torch.nn.LayerNorm(hidden_size, eps=1e-6).to(dtype).to(device) with torch.no_grad(): torch_ln.weight.copy_(liger_ln.weight) @@ -68,17 +71,17 @@ def test_liger_layer_norm_functional( ): torch.manual_seed(0) - input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device) x1 = input.clone().requires_grad_(True) x2 = input.clone().requires_grad_(True) - w = torch.randn(hidden_size, device="cuda", dtype=dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) w1 = w.clone().requires_grad_(True) w2 = w.clone().requires_grad_(True) - b = torch.randn(hidden_size, device="cuda", dtype=dtype) + b = torch.randn(hidden_size, device=device, dtype=dtype) b1 = b.clone().requires_grad_(True) b2 = b.clone().requires_grad_(True) diff --git a/test/transformers/test_mm_int8int2.py b/test/transformers/test_mm_int8int2.py index d7d13a958..a2458523a 100644 --- a/test/transformers/test_mm_int8int2.py +++ b/test/transformers/test_mm_int8int2.py @@ -6,6 +6,9 @@ pack_weights, unpack_weights, ) +from liger_kernel.utils import infer_device + +device = infer_device() # input_features = size*4 when the weight matrix is unpacked @@ -38,7 +41,7 @@ @pytest.mark.parametrize( "atol, rtol, device", [ 
- (1e-2, 1e-2, "cuda"), + (1e-2, 1e-2, device), ], ) def test_kernel_correctness( @@ -95,7 +98,7 @@ def test_kernel_correctness( @pytest.mark.parametrize( "device", [ - "cuda", + device, ], ) def test_unpack_pack_correctness(out_features, size, device): diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py index fcc54b309..69e298f5a 100644 --- a/test/transformers/test_rms_norm.py +++ b/test/transformers/test_rms_norm.py @@ -13,10 +13,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.transformers.functional import liger_rms_norm from liger_kernel.transformers.rms_norm import LigerRMSNorm +from liger_kernel.utils import infer_device + +device = infer_device() set_seed(42) torch.use_deterministic_algorithms(True) -device = infer_device() + # Only setting torch.use_deterministic_algorithms(True) might throw the following error: # RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, # but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index cc852563d..d3fe9f127 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -10,6 +10,9 @@ from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.transformers.functional import liger_rope from liger_kernel.transformers.rope import liger_rotary_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() SLEEP_SECONDS = 0.1 @@ -46,16 +49,16 @@ def test_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol ): - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) @@ -66,7 +69,7 @@ def test_correctness( q2 = _tensor_q.clone().requires_grad_(True) k2 = _tensor_k.clone().requires_grad_(True) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -77,8 +80,8 @@ def test_correctness( # validate backward pass dq, dk = ( - torch.randn_like(hf_q, device="cuda"), - torch.randn_like(hf_k, device="cuda").to(dtype), + torch.randn_like(hf_q, device=device), + torch.randn_like(hf_k, device=device).to(dtype), ) q1_grad, k1_grad = torch.autograd.grad( @@ -111,8 +114,8 @@ def test_correctness( def test_functional_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol ): - _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device="cuda", dtype=dtype) - _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device="cuda", dtype=dtype) + _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) + _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) q1 = _q.clone().requires_grad_(True) q2 = _q.clone().requires_grad_(True) @@ -120,9 +123,9 @@ def test_functional_correctness( k1 
= _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = LlamaRotaryEmbedding(head_dim, device="cuda") + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0) + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_rope(q1, k1, cos, sin) diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index be7aaef42..0f42e129a 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -10,6 +10,9 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction from liger_kernel.transformers.functional import liger_swiglu from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP +from liger_kernel.utils import infer_device + +device = infer_device() LLAMA_CONFIG = LlamaConfig( hidden_size=4096, @@ -52,22 +55,22 @@ def test_correctness_llamamlp( bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol ): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # initialize weights - G = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - U = torch.randn(hidden_size, intermediate_size, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + G = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + U = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + llama_mlp = LlamaMLP(config=LLAMA_CONFIG).to(device).to(dtype) llama_mlp.gate_proj.weight.data = G.T llama_mlp.up_proj.weight.data = U.T llama_mlp.down_proj.weight.data = D.T - liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerSwiGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype) liger_mlp.gate_proj.weight.data = G.T liger_mlp.up_proj.weight.data = U.T liger_mlp.down_proj.weight.data = D.T @@ -132,20 +135,20 @@ def test_correctness_llamamlp( def test_correctness_phi3mlp( bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol ): - _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) # initialize weights - GU = torch.randn(hidden_size, intermediate_size * 2, device="cuda", dtype=dtype) - D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + GU = torch.randn(hidden_size, intermediate_size * 2, device=device, dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype) - phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to("cuda").to(dtype) + phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to(device).to(dtype) phi3_mlp.gate_up_proj.weight.data = GU.T phi3_mlp.down_proj.weight.data = D.T - liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to("cuda").to(dtype) + liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to(device).to(dtype) liger_mlp.gate_up_proj.weight.data = GU.T liger_mlp.down_proj.weight.data = D.T @@ -193,8 +196,8 @@ def test_correctness_phi3mlp( ], ) def 
test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol): - _input = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) - _b = torch.randn(bsz, seq_len, size, device="cuda", dtype=dtype) + _input = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) + _b = torch.randn(bsz, seq_len, size, device=device, dtype=dtype) x1 = _input.clone().requires_grad_(True) x2 = _input.clone().requires_grad_(True) diff --git a/test/utils.py b/test/utils.py index 8ac0309fb..b655104f3 100644 --- a/test/utils.py +++ b/test/utils.py @@ -15,21 +15,9 @@ from tokenizers.trainers import BpeTrainer from transformers import PretrainedConfig, PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding +from liger_kernel.utils import infer_device - -def infer_device(): - """ - Get current device name based on available devices - """ - if torch.cuda.is_available(): - return "cuda" - elif torch.xpu.is_available(): - return "xpu" - else: - return "cpu" - - -torch_device = infer_device() +device = infer_device() def set_seed(seed=42): @@ -43,7 +31,7 @@ def set_seed(seed=42): # PyTorch random seed torch.manual_seed(seed) - if torch_device == "cuda": + if device == "cuda": # If you are using CUDA torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. @@ -51,8 +39,8 @@ def set_seed(seed=42): # PyTorch backend settings torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - elif torch_device == "xpu": - # If you ware using intel GPU + elif device == "xpu": + # If you are using XPU torch.xpu.manual_seed(seed) torch.xpu.manual_seed_all(seed) @@ -225,9 +213,9 @@ def train_bpe_tokenizer(special_tokens: List[str], unk_token: str = "<|unk|>"): def supports_bfloat16(): - if torch_device == "cuda": + if device == "cuda": return torch.cuda.get_device_capability() >= (8, 0) # Ampere and newer - elif torch_device == "xpu": + elif device == "xpu": return True else: return False From e84dfe09db3ce3bc3b216539e28b0d3c91978e13 Mon Sep 17 00:00:00 2001 From: Golam Rabbani Date: Tue, 19 Nov 2024 09:41:02 -0800 Subject: [PATCH 2/5] add xpu support --- benchmark/scripts/benchmark_cpo_loss.py | 7 ++--- benchmark/scripts/benchmark_qwen2vl_mrope.py | 28 +++++++++++--------- test/chunked_loss/test_cpo_loss.py | 12 ++++++--- test/transformers/test_qwen2vl_mrope.py | 24 ++++++++++------- 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index d10c8da8a..410dd6dcd 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -13,6 +13,10 @@ ) from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +70,6 @@ def bench_memory_fused_linear_cpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +110,6 @@ def bench_speed_fused_linear_cpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device) diff --git a/benchmark/scripts/benchmark_qwen2vl_mrope.py 
b/benchmark/scripts/benchmark_qwen2vl_mrope.py index 77ed61921..0b9890e68 100644 --- a/benchmark/scripts/benchmark_qwen2vl_mrope.py +++ b/benchmark/scripts/benchmark_qwen2vl_mrope.py @@ -14,6 +14,10 @@ ) from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.utils import infer_device + + +device = infer_device() def bench_speed_qwen2vl_mrope( @@ -40,23 +44,23 @@ def bench_speed_qwen2vl_mrope( ) head_dim = hidden_size // num_q_heads - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) mrope_section_hw = head_dim * 3 // 16 @@ -133,23 +137,23 @@ def bench_memory_qwen2vl_mrope( ) head_dim = hidden_size // num_q_heads - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) q = torch.randn( (1, seq_len, num_q_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) k = torch.randn( (1, seq_len, num_kv_heads, head_dim), - device="cuda", + device=device, requires_grad=True, dtype=dtype, ).transpose(1, 2) - dq, dk = torch.randn_like(q, device="cuda", dtype=dtype), torch.randn_like( - k, device="cuda" + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device ) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k, pos_ids) mrope_section_hw = head_dim * 3 // 16 diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 9211f98fd..3345f76c5 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -6,6 +6,10 @@ import torch.nn.functional as F from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() # set random seed globally set_seed() @@ -87,7 +91,7 @@ def test_correctness( ): B = 2 * B # cpo loss requires B to be even - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -98,7 +102,7 @@ def test_correctness( B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -106,11 +110,11 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = 
_weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py index fb3f4b80e..49b7bd5a5 100644 --- a/test/transformers/test_qwen2vl_mrope.py +++ b/test/transformers/test_qwen2vl_mrope.py @@ -16,6 +16,10 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction from liger_kernel.transformers.functional import liger_qwen2vl_mrope from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +from liger_kernel.utils import infer_device + + +device = infer_device() @pytest.mark.skipif( @@ -49,16 +53,16 @@ def test_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol ): - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device="cuda") + torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) .transpose(1, 2) .to(dtype) ) @@ -70,7 +74,7 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) # NOTE: this position ids distribution is different from the real one, just to test op correctness - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -81,8 +85,8 @@ def test_correctness( # validate backward pass dq, dk = ( - torch.randn_like(hf_q, device="cuda"), - torch.randn_like(hf_k, device="cuda").to(dtype), + torch.randn_like(hf_q, device=device), + torch.randn_like(hf_k, device=device).to(dtype), ) q1_grad, k1_grad = torch.autograd.grad( @@ -116,8 +120,8 @@ def test_correctness( def test_functional_correctness( bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol ): - _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device="cuda", dtype=dtype) - _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device="cuda", dtype=dtype) + _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) + _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) q1 = _q.clone().requires_grad_(True) q2 = _q.clone().requires_grad_(True) @@ -125,9 +129,9 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device="cuda") + rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device) - pos_ids = torch.arange(seq_len * 3, device="cuda", dtype=torch.long).view(3, 1, -1) + pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section) From 34d40a3cbf10ffab17eecae656b439c9249d3c43 Mon Sep 17 00:00:00 2001 From: "Rabbani, Golam" Date: Wed, 20 Nov 2024 16:27:07 +0000 Subject: [PATCH 3/5] add xpu support to simpo_loss --- 
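Notes: this patch extends to the SimPO benchmark and test the device-resolution
pattern PATCH 1 established: import infer_device from liger_kernel.utils,
resolve the device once at module scope, and drop the local device = "cuda"
assignments. For reference, a minimal sketch of the helper, assuming the new
src/liger_kernel/utils.py mirrors the definition that PATCH 1 removes from
test/utils.py:

    import torch

    def infer_device():
        """
        Get current device name based on available devices
        """
        if torch.cuda.is_available():
            return "cuda"   # NVIDIA GPUs
        elif torch.xpu.is_available():
            return "xpu"    # Intel GPUs
        else:
            return "cpu"    # CPU-only fallback
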
benchmark/scripts/benchmark_simpo_loss.py | 7 ++++--- test/chunked_loss/test_simpo_loss.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py index 457f6f2e8..fbc23cf2c 100644 --- a/benchmark/scripts/benchmark_simpo_loss.py +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -13,6 +13,10 @@ ) from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -66,7 +70,6 @@ def bench_memory_fused_linear_simpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - device = "cuda" torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) @@ -107,8 +110,6 @@ def bench_speed_fused_linear_simpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - device = "cuda" - torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py index 727aaa56e..1a2737b2c 100644 --- a/test/chunked_loss/test_simpo_loss.py +++ b/test/chunked_loss/test_simpo_loss.py @@ -5,6 +5,10 @@ import torch from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction +from liger_kernel.utils import infer_device + + +device = infer_device() # set random seed globally set_seed() @@ -33,7 +37,7 @@ def test_correctness( ): B = 2 * B # SimPO loss requires B to be even - _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -44,7 +48,7 @@ def test_correctness( B, T, ), - device="cuda", + device=device, dtype=torch.long, ) # Assign some random number of elements as ignore_index @@ -52,11 +56,11 @@ def test_correctness( indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - _weight = torch.randn(V, H, device="cuda", dtype=dtype) + _weight = torch.randn(V, H, device=device, dtype=dtype) weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) - _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None + _bias = torch.randn(V, device=device, dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None From 90d7113b418dcde7368e4bf7dcf2ff66543cbadf Mon Sep 17 00:00:00 2001 From: "Rabbani, Golam" Date: Fri, 22 Nov 2024 11:52:28 -0800 Subject: [PATCH 4/5] main to xpu-support --- test/chunked_loss/test_orpo_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 47e5a0ed8..4c95634ed 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -209,7 +209,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B = 2 * B - _input = torch.randn(B, T, H, device="cuda", 
dtype=dtype) * scalar + _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) input2 = _input.detach().clone().requires_grad_(True) @@ -220,7 +220,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): B, T, ), - device="cuda", + device=device, dtype=torch.long, ) From 1748f9c9e11f05f543dc835370ee7b7cb9c3a7e4 Mon Sep 17 00:00:00 2001 From: "Rabbani, Golam" Date: Fri, 22 Nov 2024 20:02:11 +0000 Subject: [PATCH 5/5] replace cuda with device for xpu support --- test/transformers/test_cross_entropy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index eb052ab5f..6e1dc277b 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -736,12 +736,12 @@ def test_float32_internal(): reduction = "mean" # Initialize input tensors - X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device="cuda") - Y = torch.randint(0, n_cols, (batch_size,), device="cuda") + X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device) + Y = torch.randint(0, n_cols, (batch_size,), device=device) # Run kernel for bfloat16 X_bf16 = X_init.clone() - loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_bf16, X_stride=X_bf16.stride(-2), @@ -765,7 +765,7 @@ def test_float32_internal(): # Run kernel for float32 X_fp32 = X_init.float() - loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device="cuda") + loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device) liger_cross_entropy_kernel[(batch_size,)]( X_ptr=X_fp32, X_stride=X_fp32.stride(-2),
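
A quick smoke check of the resolved backend, independent of the patches above
(a sketch; the tensor shape is arbitrary):

    import torch
    from liger_kernel.utils import infer_device

    device = infer_device()
    x = torch.randn(4, 8, device=device, requires_grad=True)
    x.square().mean().backward()
    print(f"backend={device}, grad shape={tuple(x.grad.shape)}")

The same script runs unchanged on CUDA, XPU, or CPU hosts, which is the point
of routing every benchmark and test through infer_device() rather than a
hard-coded device="cuda".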