Fix (optim/cailey_sgd): fix cailey sgd in float16/bfloat16 #1193

Merged
merged 5 commits into from Feb 20, 2025
11 changes: 6 additions & 5 deletions src/brevitas/optim/cailey_sgd.py
@@ -43,7 +43,10 @@ def Cayley_loop(X, W, tan_vec, t): #
def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n
[p, n] = tan_vec.size()
tan_vec.t_()
q, r = torch.linalg.qr(tan_vec)
dtype = tan_vec.dtype
# torch.linalg.qr is not implemented for 'Half'
q, r = torch.linalg.qr(tan_vec.to(torch.float32))
q, r = q.to(dtype=dtype), r.to(dtype=dtype)
d = torch.diag(r, 0)
ph = d.sign()
q *= ph.expand_as(q)
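For reference, the upcast/downcast pattern used here can be exercised on its own: torch.linalg.qr has no 'Half'/'BFloat16' kernels, so the retraction computes the factorization in float32 and casts the factors back to the original dtype. A minimal sketch (the helper name qr_in_low_precision is hypothetical, not part of the PR):

```python
import torch

def qr_in_low_precision(mat: torch.Tensor):
    """QR on a half-precision tensor by round-tripping through float32,
    mirroring the pattern used in qr_retraction above."""
    orig_dtype = mat.dtype
    # torch.linalg.qr has no 'Half'/'BFloat16' kernels, so factorize in float32 ...
    q, r = torch.linalg.qr(mat.to(torch.float32))
    # ... then cast the factors back to the caller's dtype.
    return q.to(orig_dtype), r.to(orig_dtype)

# Example: a bfloat16 tangent vector, as in the optimizer's retraction step.
tan_vec = torch.randn(5, 3, dtype=torch.bfloat16)
q, r = qr_in_low_precision(tan_vec)
assert q.dtype == torch.bfloat16 and r.dtype == torch.bfloat16
```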
@@ -87,7 +90,7 @@ class CaileySGD(Optimizer):
def __init__(
self,
params,
lr: float = 1e-3,
lr: float = 1e-1,
momentum: int = 0,
dampening: int = 0,
weight_decay: int = 0,
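For context, a hedged usage sketch with the new default learning rate; the keyword arguments are taken from this diff and from the OPTIMIZER_KWARGS used in the tests below, and the layer shown is only illustrative, not part of the PR:

```python
import torch.nn as nn

from brevitas.optim.cailey_sgd import CaileySGD

# Hypothetical usage: keep a small weight matrix (more columns than rows)
# on the Stiefel manifold, relying on the new default learning rate of 1e-1.
layer = nn.Linear(in_features=5, out_features=3, bias=False)
optimizer = CaileySGD([layer.weight], lr=1e-1, stiefel=True)
```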
@@ -150,9 +153,7 @@ def step(self, closure=None):

param_state = self.state[p]
if "momentum_buffer" not in param_state:
param_state["momentum_buffer"] = torch.zeros(g.t().size())
if p.is_cuda:
param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda()
param_state["momentum_buffer"] = torch.zeros_like(g.t())

V = param_state["momentum_buffer"]
V = momentum * V - g.t()
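The momentum-buffer change swaps an explicit torch.zeros plus a manual .cuda() move for torch.zeros_like, which inherits shape, dtype and device from the gradient in one call. A small sketch of the difference, assuming a float16 gradient on CPU:

```python
import torch

g = torch.randn(3, 5, dtype=torch.float16)  # gradient of a float16 parameter

# Old pattern: always allocates a float32 CPU buffer, then optionally moves it
# to CUDA, so the buffer's dtype never matches a float16/bfloat16 parameter.
buf_old = torch.zeros(g.t().size())
assert buf_old.dtype == torch.float32

# New pattern: inherit shape, dtype and device from the gradient in one call.
buf_new = torch.zeros_like(g.t())
assert buf_new.dtype == g.dtype and buf_new.device == g.device
```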
30 changes: 18 additions & 12 deletions tests/brevitas/optim/test_cailey_sgd.py
@@ -42,37 +42,34 @@

from copy import deepcopy
from itertools import product
import math
import sys
from typing import List, Union
import unittest

from hypothesis import given
import numpy as np
from packaging import version
import pytest
import pytest_cases
from pytest_cases import fixture
from scipy.stats import ortho_group
import torch
from torch.nn import Parameter
import torch.nn as nn
from torch.optim.lr_scheduler import LinearLR

from brevitas import torch_version
from brevitas.optim.cailey_sgd import CaileySGD
from tests.conftest import SEED

torch.manual_seed(SEED)

OPTIMIZER_KWARGS = [{
"stiefel": True}, {
"stiefel": True, "lr": 1e-2}, {
"stiefel": True, "lr": torch.tensor(0.001)}]
"stiefel": True, "lr": 0.5}, {
"stiefel": True, "lr": torch.tensor(0.5)}]
LR_SCHEDULER_ARGS = [
None,
(LinearLR, {
"start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),]
DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
DTYPES = [torch.float32]
DTYPES = ["float32", "float16", "bfloat16"]

device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES)))
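The test matrix is the Cartesian product of DEVICES and DTYPES, and the dtype strings are resolved to torch dtypes inside the test body via getattr. A minimal sketch of what that expands to, assuming a CPU-only run:

```python
from itertools import product

import torch

DEVICES = ["cpu"]  # plus "cuda" when torch.cuda.is_available()
DTYPES = ["float32", "float16", "bfloat16"]

# Each test case receives one (device, dtype) pair from the product ...
print(list(product(DEVICES, DTYPES)))
# [('cpu', 'float32'), ('cpu', 'float16'), ('cpu', 'bfloat16')]

# ... and the dtype string is resolved to a torch dtype inside the test body.
assert getattr(torch, "bfloat16") is torch.bfloat16
```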

@@ -83,7 +80,13 @@ class TestCaileySGD:
@pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
@pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
if torch_version < version.parse('2.3.1') and dtype in ["float16", "bfloat16"]:
pytest.skip(
"Some operations in the CaileySGD optimizer (e.g. diag, eye) are not implemented for 'Half' or 'BFloat16' in PyTorch versions under 2.3.1."
)
torch.manual_seed(SEED)
optim_cls = CaileySGD
dtype = getattr(torch, dtype)
# Generate a random orthogonal matrix of size NxN. Columns represent orthonormal vector in R^{N}
N = 5
P = 3
@@ -108,6 +111,8 @@ def closure():
return loss

initial_value = closure().item()
ATOL = 1e-2 if dtype == torch.float32 else 1e-1
RTOL = 1e-3 if dtype == torch.float32 else 1e-2
for _ in range(20):
closure()
optimizer.step()
@@ -116,10 +121,11 @@ def closure():

# Verify that iterates stay within the Stiefel manifold
assert torch.allclose(
weight.detach().cpu() @ weight.detach().cpu().t(),
torch.eye(P, P, device=device, dtype=dtype).detach().cpu(),
atol=1e-5,
rtol=1e-6)
weight.to(dtype=torch.float32).detach().cpu()
@ weight.to(dtype=torch.float32).detach().cpu().t(),
torch.eye(P, P, device=device, dtype=torch.float32).detach().cpu(),
atol=ATOL,
rtol=RTOL)

if optimizer_kwargs.get("maximize", False):
assert closure().item() > initial_value
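The manifold assertion casts the weight to float32 before comparing W @ W.T against the identity, with looser tolerances for float16/bfloat16. A standalone sketch of the same check (the helper on_stiefel_manifold is hypothetical):

```python
import torch

def on_stiefel_manifold(weight: torch.Tensor, atol: float, rtol: float) -> bool:
    """Check that the rows of `weight` are orthonormal (W @ W.T ~= I),
    comparing in float32 so half-precision weights can be verified too."""
    w = weight.detach().to(torch.float32).cpu()
    identity = torch.eye(w.shape[0], dtype=torch.float32)
    return torch.allclose(w @ w.t(), identity, atol=atol, rtol=rtol)

# A 3x5 matrix with orthonormal rows, stored in bfloat16 as in the test.
weight = torch.linalg.qr(torch.randn(5, 3))[0].t().to(torch.bfloat16)
# Looser tolerances for half precision, mirroring ATOL/RTOL above.
assert on_stiefel_manifold(weight, atol=1e-1, rtol=1e-2)
```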