[XPU] Implemented 32bit optimizers in triton #1710

Open
wants to merge 5 commits into base: main

252 changes: 251 additions & 1 deletion bitsandbytes/backends/default/ops.py
@@ -1,5 +1,5 @@
from collections.abc import Sequence
-from math import prod
+from math import prod, sqrt
from typing import Optional

import torch
@@ -301,3 +301,253 @@ def _(
B_dq,
bias=None,
)


MOMENTUM = 0
RMSPROP = 1
ADAGRAD = 2
ADAM = 3
# LION's ID must be larger than MOMENTUM, RMSPROP, and ADAGRAD because the kernels compare optimizer IDs against it
LION = 4
ADEMAMIX = 5

name2optimizer_id = {
"momentum": MOMENTUM,
"rmsprop": RMSPROP,
"adagrad": ADAGRAD,
"adam": ADAM,
"lion": LION,
"ademamix": ADEMAMIX,
}


@torch.compile
def _optimizer_precondition_32bit(
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: torch.Tensor,
beta1: float,
beta2: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
optimizer_id: int,
):
"""Preprocessing optimizer, computing update norm"""

g_vals = gnorm_scale * g

if optimizer_id == 3: # ADAM
correction1 = 1.0 / (1.0 - beta1**step)
correction2 = 1.0 / (1.0 - beta2**step)

s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

s1_vals = s1_vals * correction1
s2_vals = s2_vals * correction2

update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
update_norm = update_vals * update_vals

elif optimizer_id == 5: # ADEMAMIX
update_norm = state1

elif optimizer_id == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = state1 * beta1 + g_vals
update_norm = s1_vals * s1_vals

elif optimizer_id == 4: # LION
s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
update_norm = s1_vals

elif optimizer_id == 1: # RMSPROP
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals

elif optimizer_id == 2: # ADAGRAD
s1_vals = state1 + g_vals * g_vals
update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
update_norm = update_vals * update_vals

total_norm = torch.sum(update_norm)
unorm_vec.add_(total_norm)


@torch.compile
def _optimizer_update_32bit(
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float,
optimizer_id: int,
):
"""Unified optimizer update kernel"""

p_vals = p.float()
g_vals = (gnorm_scale * g).float()
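    # 1-state optimizers (momentum, rmsprop, adagrad, lion) fold weight decay into the
    # gradient; adam and ademamix apply decoupled weight decay to the parameters below.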
if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
g_vals = g_vals + p_vals * weight_decay

update_scale = 1.0
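    # Update clipping: if the update norm accumulated by the precondition pass exceeds
    # max_unorm * param_norm, scale the step down proportionally.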
if max_unorm > 0.0:
current_unorm = torch.sqrt(unorm_vec)
if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers
if current_unorm > max_unorm * param_norm + eps:
update_scale = (max_unorm * param_norm + eps) / current_unorm
else: # 2-state optimizers
if current_unorm > max_unorm * param_norm:
update_scale = (max_unorm * param_norm) / current_unorm

if optimizer_id == 3: # ADAM
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

correction1 = 1.0 - beta1**step
correction2 = sqrt(1.0 - beta2**step)
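        # Fold the bias corrections into the step size: -lr * sqrt(1 - beta2**step) / (1 - beta1**step)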
step_size = -lr * correction2 / correction1

if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)

update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
p_vals = p_vals + update_val

state1.copy_(s1_vals)
state2.copy_(s2_vals)

elif optimizer_id == 5: # ADEMAMIX
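        # For AdeMAMix, state1 stacks two EMAs: state1[0] is the fast EMA (beta1) and
        # state1[1] the slow EMA (beta3); state2 holds the second moment.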
s1_vals = state1[0]
s3_vals = state1[1]
s2_vals = state2

m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals

correction1 = 1.0 - beta1**step
correction2 = sqrt(1.0 - beta2**step)

if weight_decay > 0.0:
p_vals = p_vals * (1.0 - lr * weight_decay)

mixed_momentum = (m1 / correction1) + (alpha * m2)
adaptive_term = (torch.sqrt(nu) / correction2) + eps
p_vals = p_vals - lr * (mixed_momentum / adaptive_term)

state1[0].copy_(m1)
state1[1].copy_(m2)
state2.copy_(nu)

elif optimizer_id == 0: # MOMENTUM
if step == 1:
s1_vals = g_vals
else:
s1_vals = state1 * beta1 + g_vals

update_val = update_scale * (-lr * s1_vals)
p_vals = p_vals + update_val

state1.copy_(s1_vals)

elif optimizer_id == 4: # LION
momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
update_val = update_scale * lr * torch.sign(momentum_update)
p_vals = p_vals - update_val

s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
state1.copy_(s1_vals)

elif optimizer_id == 1: # RMSPROP
s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val

state1.copy_(s1_vals)

elif optimizer_id == 2: # ADAGRAD
s1_vals = state1 + g_vals * g_vals
update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
p_vals = p_vals - update_val

state1.copy_(s1_vals)

p.copy_(p_vals)


@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
unorm_vec: Optional[torch.Tensor],
max_unorm: float,
param_norm: float,
beta1: float,
beta2: float,
beta3: float,
alpha: float,
eps: float,
weight_decay: float,
step: int,
lr: float,
gnorm_scale: float = 1.0,
    skip_zeros: bool = False,
) -> None:
"""
    32-bit optimizer update implemented in plain PyTorch with @torch.compile.
"""
if skip_zeros:
raise NotImplementedError("skip_zeros is not supported yet")

optimizer_id = name2optimizer_id[optimizer_name]

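    # LION applies its parameter update first (using the previous-step momentum) and only then
    # accumulates the update norm for the next step; the other optimizers accumulate the norm
    # before applying the update.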
if optimizer_name == "lion":
_optimizer_update_32bit(
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)

if max_unorm > 0.0:
unorm_vec.zero_()
_optimizer_precondition_32bit(
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
else:
if max_unorm > 0.0:
unorm_vec.zero_()
_optimizer_precondition_32bit(
g, p, state1, state2, unorm_vec,
beta1, beta2, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)

_optimizer_update_32bit(
g, p, state1, state2, unorm_vec, max_unorm, param_norm,
beta1, beta2, beta3, alpha, eps, weight_decay, step,
lr, gnorm_scale, optimizer_id
)
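
For context, a minimal sketch of how the op registered above could be exercised through the PyTorch custom-op dispatcher, here with Adam. The tensor sizes and hyperparameters are illustrative only, and it assumes that importing bitsandbytes defines and registers bitsandbytes::optimizer_update_32bit with the same schema as the kernel signature shown in this diff:

import torch
import bitsandbytes  # noqa: F401  (assumed to define/register the custom op)

p = torch.randn(1024, dtype=torch.float32)  # parameters
g = torch.randn_like(p)                     # gradients
state1 = torch.zeros_like(p)                # Adam first moment
state2 = torch.zeros_like(p)                # Adam second moment

torch.ops.bitsandbytes.optimizer_update_32bit(
    "adam", g, p, state1, state2,
    None,        # unorm_vec (unused when max_unorm == 0)
    0.0, 0.0,    # max_unorm, param_norm
    0.9, 0.999,  # beta1, beta2
    0.0, 0.0,    # beta3, alpha (unused by Adam)
    1e-8,        # eps
    0.0,         # weight_decay
    1,           # step
    1e-3,        # lr
)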