diff --git a/src/brevitas/optim/cailey_sgd.py b/src/brevitas/optim/cailey_sgd.py
index 2e2426fee..265fb2c75 100644
--- a/src/brevitas/optim/cailey_sgd.py
+++ b/src/brevitas/optim/cailey_sgd.py
@@ -43,7 +43,10 @@ def Cayley_loop(X, W, tan_vec, t):  #
 def qr_retraction(tan_vec):  # tan_vec, p-by-n, p <= n
     [p, n] = tan_vec.size()
     tan_vec.t_()
-    q, r = torch.linalg.qr(tan_vec)
+    dtype = tan_vec.dtype
+    # torch.linalg.qr is not implemented for 'Half'
+    q, r = torch.linalg.qr(tan_vec.to(torch.float32))
+    q, r = q.to(dtype=dtype), r.to(dtype=dtype)
     d = torch.diag(r, 0)
     ph = d.sign()
     q *= ph.expand_as(q)
@@ -87,7 +90,7 @@ class CaileySGD(Optimizer):
     def __init__(
             self,
             params,
-            lr: float = 1e-3,
+            lr: float = 1e-1,
             momentum: int = 0,
             dampening: int = 0,
             weight_decay: int = 0,
@@ -150,9 +153,7 @@ def step(self, closure=None):
 
                     param_state = self.state[p]
                     if "momentum_buffer" not in param_state:
-                        param_state["momentum_buffer"] = torch.zeros(g.t().size())
-                        if p.is_cuda:
-                            param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda()
+                        param_state["momentum_buffer"] = torch.zeros_like(g.t())
 
                     V = param_state["momentum_buffer"]
                     V = momentum * V - g.t()
diff --git a/tests/brevitas/optim/test_cailey_sgd.py b/tests/brevitas/optim/test_cailey_sgd.py
index 92de8ae5a..40b6a4675 100644
--- a/tests/brevitas/optim/test_cailey_sgd.py
+++ b/tests/brevitas/optim/test_cailey_sgd.py
@@ -42,22 +42,19 @@
 from copy import deepcopy
 from itertools import product
-import math
-import sys
-from typing import List, Union
-import unittest
 
 from hypothesis import given
 import numpy as np
+from packaging import version
 import pytest
 import pytest_cases
 from pytest_cases import fixture
 from scipy.stats import ortho_group
 import torch
 from torch.nn import Parameter
-import torch.nn as nn
 from torch.optim.lr_scheduler import LinearLR
 
+from brevitas import torch_version
 from brevitas.optim.cailey_sgd import CaileySGD
 from tests.conftest import SEED
@@ -65,14 +62,14 @@
 OPTIMIZER_KWARGS = [{
     "stiefel": True}, {
-        "stiefel": True, "lr": 1e-2}, {
-            "stiefel": True, "lr": torch.tensor(0.001)}]
+        "stiefel": True, "lr": 0.5}, {
+            "stiefel": True, "lr": torch.tensor(0.5)}]
 LR_SCHEDULER_ARGS = [
     None,
     (LinearLR, {
         "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),]
 DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-DTYPES = [torch.float32]
+DTYPES = ["float32", "float16", "bfloat16"]
 
 device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES)))
@@ -83,7 +80,13 @@ class TestCaileySGD:
     @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS)
     @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS)
     def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args):
+        if torch_version < version.parse('2.3.1') and dtype in ["float16", "bfloat16"]:
+            pytest.skip(
+                "Some operations in the CaileySGD optimizer (e.g. diag, eye) are not implemented for 'Half' or 'BFloat16' in PyTorch versions under 2.3.1."
+            )
+        torch.manual_seed(SEED)
         optim_cls = CaileySGD
+        dtype = getattr(torch, dtype)
         # Generate a random orthogonal matrix of size NxN. Columns represent orthonormal vector in R^{N}
         N = 5
         P = 3
@@ -108,6 +111,8 @@ def closure():
             return loss
 
         initial_value = closure().item()
+        ATOL = 1e-2 if dtype == torch.float32 else 1e-1
+        RTOL = 1e-3 if dtype == torch.float32 else 1e-2
         for _ in range(20):
             closure()
             optimizer.step()
@@ -116,10 +121,11 @@ def closure():
 
         # Verify that iterates stay within the Stiefel manifold
         assert torch.allclose(
-            weight.detach().cpu() @ weight.detach().cpu().t(),
-            torch.eye(P, P, device=device, dtype=dtype).detach().cpu(),
-            atol=1e-5,
-            rtol=1e-6)
+            weight.to(dtype=torch.float32).detach().cpu()
+            @ weight.to(dtype=torch.float32).detach().cpu().t(),
+            torch.eye(P, P, device=device, dtype=torch.float32).detach().cpu(),
+            atol=ATOL,
+            rtol=RTOL)
 
         if optimizer_kwargs.get("maximize", False):
             assert closure().item() > initial_value
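
Reviewer note: to sanity-check the half-precision path outside the test suite, below is a minimal sketch (not part of this patch) of the float32 round-trip the patched `qr_retraction` performs, followed by the same Stiefel-manifold check the test applies. `qr_retraction_sketch` is a hypothetical standalone mirror of the patched function; it assumes PyTorch >= 2.3.1 (per the skip condition added above), and the shapes (P = 3, N = 5) and relaxed tolerances match the test.

```python
import torch


def qr_retraction_sketch(tan_vec: torch.Tensor) -> torch.Tensor:
    """Sketch of the patched qr_retraction: tan_vec is p-by-n, p <= n."""
    tan_vec = tan_vec.t()
    dtype = tan_vec.dtype
    # torch.linalg.qr has no 'Half'/'BFloat16' kernel, so round-trip via float32.
    q, r = torch.linalg.qr(tan_vec.to(torch.float32))
    q, r = q.to(dtype=dtype), r.to(dtype=dtype)
    # Fix the sign ambiguity of QR so the retraction is deterministic.
    ph = torch.diag(r, 0).sign()
    q = q * ph.expand_as(q)
    return q.t()


# Quick check in float16, using the relaxed tolerances from the test.
x = torch.randn(3, 5, dtype=torch.float16)  # p = 3, n = 5
q = qr_retraction_sketch(x)
# Rows of q should be orthonormal: q @ q.t() ~ I, compared in float32
# because allclose in half precision would be too noisy.
gram = q.to(torch.float32) @ q.to(torch.float32).t()
assert torch.allclose(gram, torch.eye(3), atol=1e-1, rtol=1e-2)
```

Computing the QR factorization in float32 and casting back keeps weights and momentum buffers in their original dtype while avoiding the missing low-precision kernel, and the sign correction on the diagonal of R makes the factorization unique, mirroring why the test compares against the identity in float32 with dtype-dependent ATOL/RTOL.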