Enable keyword arguments for liger functional (#400)
## Summary

This PR enables keyword arguments for the Liger functional API, addressing #368.

1. Wrap each Liger operator function (`torch.autograd.Function`) with an
extra layer that accepts keyword arguments (see the sketch after this list).
2. For each Liger function, update its unit test
`test_{operator_name}.py::test_correctness_functional` to verify that
keyword arguments are accepted.
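
For illustration, a minimal sketch of the wrapping pattern (it mirrors the
`liger_geglu` wrapper added in `src/liger_kernel/transformers/functional.py`
below; the other operators follow the same shape):

```python
from liger_kernel.ops.geglu import LigerGELUMulFunction


def liger_geglu(a, b):
    # A plain Python function accepts keyword arguments, whereas calling
    # torch.autograd.Function.apply directly is positional-only.
    return LigerGELUMulFunction.apply(a, b)


# Callers can now use either style:
# y = liger_geglu(x, gate)
# y = liger_geglu(a=x, b=gate)
```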

## Testing Done

- Hardware Type: A10G
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Hongpeng Guo <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
hongpeng-guo and ByronHsu authored Nov 21, 2024
1 parent 998f4e4 commit 317ff43
Showing 10 changed files with 154 additions and 28 deletions.
2 changes: 1 addition & 1 deletion dev/modal/tests.py
@@ -14,7 +14,7 @@
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
@app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
def liger_tests():
import subprocess

139 changes: 127 additions & 12 deletions src/liger_kernel/transformers/functional.py
@@ -15,18 +15,6 @@
from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.ops.swiglu import LigerSiLUMulFunction

liger_swiglu = LigerSiLUMulFunction.apply
liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
liger_geglu = LigerGELUMulFunction.apply
liger_rms_norm = LigerRMSNormFunction.apply
liger_rope = LigerRopeFunction.apply
liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
liger_layer_norm = LigerLayerNormFunction.apply
liger_kl_div = LigerKLDivLossFunction.apply
liger_jsd = LigerJSDFunction.apply
liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
liger_group_norm = LigerGroupNormFunction.apply


# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
# `weight` and `size_average` are placeholders and not implemented yet
@@ -56,3 +44,130 @@ def liger_cross_entropy(
if not return_z_loss:
return loss
return loss, z_loss


def liger_fused_linear_cross_entropy(
input,
weight,
target,
bias=None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
):
return LigerFusedLinearCrossEntropyFunction.apply(
input,
weight,
target,
bias,
ignore_index,
lse_square_scale,
label_smoothing,
reduction,
softcap,
)


def liger_fused_linear_jsd(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels=None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
return LigerFusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
temperature,
)


def liger_geglu(a, b):
return LigerGELUMulFunction.apply(a, b)


def liger_group_norm(
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
):
return LigerGroupNormFunction.apply(
X,
affine_scaling_weight,
affine_shifting_bias,
num_channels,
num_groups,
eps,
)


def liger_jsd(
input,
target,
shift_labels=None,
beta: float = 0.5,
ignore_index: int = -100,
):
return LigerJSDFunction.apply(
input,
target,
shift_labels,
beta,
ignore_index,
)


# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
# `size_average` and `reduce` are deprecated in the torch API and are placeholders here
def liger_kl_div(
input,
target,
size_average: bool = True,
reduce: bool = True,
reduction: str = "mean",
log_target: bool = False,
eps: float = 1e-10,
):
# Note: the default reduction in torch is `mean`, but it is `batchmean` in Liger
return LigerKLDivLossFunction.apply(
input,
target,
reduction,
log_target,
eps,
)


def liger_layer_norm(X, W, B, eps):
return LigerLayerNormFunction.apply(X, W, B, eps)


def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


def liger_rms_norm(
X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
):
return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)


def liger_swiglu(a, b):
return LigerSiLUMulFunction.apply(a, b)
7 changes: 6 additions & 1 deletion test/transformers/test_fused_linear_cross_entropy.py
@@ -244,7 +244,12 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
weight = torch.randn(V, H, device=device, dtype=dtype)
bias = torch.randn(V, device=device, dtype=dtype) if bias else None

y1 = liger_fused_linear_cross_entropy(x1, weight, target, bias)
y1 = liger_fused_linear_cross_entropy(
input=x1,
weight=weight,
target=target,
bias=bias,
)
y2 = LigerFusedLinearCrossEntropyFunction.apply(x2, weight, target, bias)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
16 changes: 8 additions & 8 deletions test/transformers/test_fused_linear_jsd.py
@@ -296,14 +296,14 @@ def test_correctness_functional(
label[indices_to_assign] = ignore_index

output1 = liger_fused_linear_jsd(
_input1,
_weight1,
teacher_input,
teacher_weight,
label,
beta,
ignore_index,
temperature,
student_input=_input1,
student_weight=_weight1,
teacher_input=teacher_input,
teacher_weight=teacher_weight,
shift_labels=label,
jsd_beta=beta,
ignore_index=ignore_index,
temperature=temperature,
)
output2 = LigerFusedLinearJSDFunction.apply(
_input2,
2 changes: 1 addition & 1 deletion test/transformers/test_geglu.py
@@ -130,7 +130,7 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
b1 = _b.clone().requires_grad_(True)
b2 = _b.clone().requires_grad_(True)

y1 = liger_geglu(x1, b1)
y1 = liger_geglu(a=x1, b=b1)
y2 = LigerGELUMulFunction.apply(x2, b2)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
8 changes: 7 additions & 1 deletion test/transformers/test_jsd.py
@@ -229,7 +229,13 @@ def _test_correctness_functional(
label[indices_to_assign] = ignore_index

output = LigerJSDFunction.apply(x1, target, label, beta, ignore_index)
output2 = liger_jsd(x2, target, label, beta, ignore_index)
output2 = liger_jsd(
input=x2,
target=target,
shift_labels=label,
beta=beta,
ignore_index=ignore_index,
)
assert torch.allclose(output, output2, atol=atol, rtol=rtol)
if (
not is_last_layer
2 changes: 1 addition & 1 deletion test/transformers/test_layer_norm.py
@@ -83,7 +83,7 @@ def test_liger_layer_norm_functional(
b1 = b.clone().requires_grad_(True)
b2 = b.clone().requires_grad_(True)

y1 = liger_layer_norm(x1, w1, b1, 1e-6)
y1 = liger_layer_norm(X=x1, W=w1, B=b1, eps=1e-6)
y2 = LigerLayerNormFunction.apply(x2, w2, b2, 1e-6)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
2 changes: 1 addition & 1 deletion test/transformers/test_rms_norm.py
@@ -182,7 +182,7 @@ def test_correctness_functional(

w = torch.randn(hd, device=device, dtype=dtype)

y1 = liger_rms_norm(h1, w, 1e-6, offset, casting_mode)
y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
2 changes: 1 addition & 1 deletion test/transformers/test_rope.py
@@ -125,7 +125,7 @@ def test_functional_correctness(
pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_rope(q1, k1, cos, sin)
functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)
class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin)

assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol)
2 changes: 1 addition & 1 deletion test/transformers/test_swiglu.py
@@ -202,7 +202,7 @@ def test_correctness_functional(bsz, seq_len, size, dtype, atol, rtol):
b1 = _b.clone().requires_grad_(True)
b2 = _b.clone().requires_grad_(True)

y1 = liger_swiglu(x1, b1)
y1 = liger_swiglu(a=x1, b=b1)
y2 = LigerSiLUMulFunction.apply(x2, b2)

assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
