Skip to content

Commit

Permalink
Implements vectorized MINRES (#2)
Browse files Browse the repository at this point in the history
* Work on block minres

* Adds multiple rhs support to minres
  • Loading branch information
PTNobel authored Apr 20, 2024
1 parent a66bb0e commit 06e500b
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 67 deletions.
54 changes: 30 additions & 24 deletions linops/block_minres.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,62 @@ def bminres(A, B, M=None, X0=None, tol=1e-5, maxiters=None, verbose=True):
#if M is None:
# M = lo.IdentityOperator(m)
if maxiters is None:
maxiters = 5 * n
maxiters = n
if X0 is None:
X0 = torch.zeros_like(B)
else:
X = X0

k = 1

R_hat_km1 = B - A @ X0
R_hat_0 = B - A @ X0

X_hat = X0
Z_km1 = torch.zeros_like(X0)
Z_k, nu_k_inv = torch.qr(R_hat_km1)
RHS_phi = nu_k_inv
phi_k = torch.zeros(s)
Z_k, nu_k_inv = torch.linalg.qr(R_hat_0)
RHS_psi = nu_k_inv
psi_km1 = torch.zeros(s, s)
W_bar_k = Z_k
rho_km1 = 0
#rho_km1 = 0
AZ_k = A @ Z_k
rho_k = Z_k.T @ AZ_k
L_bar_k_k = rho_k
V_kp1 = torch.eye(2 * s)
V_k_T = torch.eye(2 * s)

for k in range(1, maxiters):
Z_k, nu_k_inv, Z_km1 = torch.qr(AZ_k - Z_k @ rho_k - Z_km1 @ nu_k_inv.T), Z_k
Z_k, nu_k_inv, Z_km1 = *torch.linalg.qr(AZ_k - Z_k @ rho_k - Z_km1 @ nu_k_inv.T), Z_k
print(torch.linalg.norm(torch.eye(s) - Z_k.T @ Z_k))
AZ_k = A @ Z_k
rho_k = Z_k @ AZ_k
temp = nu_k_inv @ V_k.T[s:, :]
L_k_km2, L_bar_kp1_k = temp[:s, :], temp[s:, :]
rho_k = Z_k.T @ AZ_k
temp = nu_k_inv @ V_k_T[s:, :] # Double check this step
L_k_km2, L_bar_kp1_k = temp[:, :s], temp[:, s:]

temp = torch.vstack([L_bar_k_k.T, nu_k_inv])
V_kp1_T, L_k_k_T_aug = torch.qr(temp, mode='complete')
V_k_T, L_k_k_T_aug = torch.linalg.qr(temp, 'complete') # V_k_T should not be orthogonal. V_k should be.
L_k_k = L_k_k_T_aug[:s, :].T
print(torch.linalg.norm(torch.eye(2 * s) - V_k_T @ V_k_T.T))
print(torch.linalg.norm(torch.eye(2 * s) - V_k_T.T @ V_k_T))

temp = torch.hstack([L_bar_kp1_k, rho_k]) @ V_kp1_T
L_kp1_k, L_bar_k_k = temp[s:, :], temp[:s, :]
temp = torch.hstack([L_bar_kp1_k, rho_k]) @ V_k_T
L_kp1_k, L_bar_k_k = temp[:, s:], temp[:, :s]


temp = torch.hstack([W_bar_k, Z_k]) @ V_kp1_T
W_k, W_bar_k = temp[s:, :], temp[:s, :]
temp = torch.hstack([W_bar_k, Z_k]) @ V_k_T
W_k, W_bar_k = temp[:, s:], temp[:, :s]

# TODO: Check if better solve; should be triangular
phi_km1 = phi_k
phi_k = torch.linalg.solve(L_k_k, RHS_phi)
RHS_phi = -(L_kp1_k @ phi_k + L_k_km2 @ phi_km1)
psi_k = torch.linalg.solve(L_k_k, RHS_psi)
RHS_psi = -(L_kp1_k @ psi_k + L_k_km2 @ psi_km1)

X_hat += W_k @ psi_k
psi_km1 = psi_k
del psi_k

X_hat += W_k @ phi_k
psi_bar_kp1 = torch.linalg.solve(L_bar_k_k, RHS_psi)
X = X_hat + W_bar_k @ psi_bar_kp1
print(k, torch.linalg.norm(B - A @ X) / torch.linalg.norm(B))

# TODO: Check if better solve; should be triangular
phi_bar_kp1 = torch.linalg.solve(L_bar_k_k, RHS_phi)
return X_hat + W_bar_k @ phi_bar_kp1, k + 1
psi_bar_kp1 = torch.linalg.solve(L_bar_k_k, RHS_psi)
return X_hat + W_bar_k @ psi_bar_kp1, k + 1



Expand Down
88 changes: 50 additions & 38 deletions linops/minres.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,59 +49,68 @@ def minres(A, b, M=None, x0=None, tol=1e-5, maxiters=None, verbose=True):
else:
x = x0
iters = 0
Anorm = 0
Acond = 0
eps = torch.tensor(torch.finfo(b.dtype).eps, device=x.device)
# Stopping criterion only
Anorm = 0 # scalar

r1 = b - A @ x
y = M @ r1
# Stopping criterion only
Acond = 0 # scalar
eps = torch.tensor(torch.finfo(b.dtype).eps, device=x.device) # scalar

beta1 = r1 @ y
assert beta1 >= 0
if beta1 == 0:
r1 = b - A @ x # Supports block
y = M @ r1 # Supports block

#beta1 = r1 @ y # Needs modification
beta1 = inner(r1, y)
assert beta1.min() >= 0, "M must be PD"
if (beta1 == 0).all():
return x
bnorm = torch.linalg.vector_norm(b)
if bnorm == 0:
return b

beta1 = torch.sqrt(beta1)
beta1 = torch.sqrt(beta1) # Supports block

oldb = torch.zeros_like(beta1) # Supports block
beta = beta1 # Supports block
dbar = torch.zeros_like(beta1) # Supports block
epsilon = torch.zeros_like(beta1) # Supports block
phibar = beta1 # Supports block
rhs1 = beta1 # Supports block
rhs2 = torch.zeros_like(beta1) # Supports block

oldb = 0
beta = beta1
dbar = 0
epsilon = 0
phibar = beta1
rhs1 = beta1
rhs2 = 0
tnorm2 = torch.tensor(0.0, device=x.device)
# Stopping criterion only
tnorm2 = torch.zeros_like(beta1)
# Only for illconditioning detection
gmax = torch.tensor(0.0, device=x.device)
# Only for illconditioning detection
gmin = torch.tensor(torch.finfo(x.dtype).max, device=x.device)
cs = -1
sn = 0
w = torch.zeros_like(b)
w2 = torch.zeros_like(b)
r2 = r1

cs = -torch.ones_like(beta1)
sn = torch.zeros_like(beta1) # Supports block
w = torch.zeros_like(b) # Supports block
w2 = torch.zeros_like(b) # Supports block
r2 = r1 # Supports block

while iters < maxiters:
iters += 1
s = 1.0 / beta
v = s * y
y = A @ v
s = 1.0 / beta # Supports block
v = s * y # Supports block
y = A @ v # Supports block
if iters >= 2:
y = y - (beta / oldb) * r1

alpha = v @ y
alpha = inner(v, y)
y = y - (alpha / beta) * r2
r1 = r2
r2 = y
y = M @ r2
oldb = beta
beta = r2 @ y
assert beta >= 0
beta = inner(r2, y)
assert (beta >= 0).all()
beta = torch.sqrt(beta)
tnorm2 += alpha**2 + oldb**2 + beta**2
if iters == 1:
if beta / beta1 <= 10 * eps:
if (beta / beta1).min() <= 10 * eps:
#assert False, "I think this occurs when A = c * I"
pass

Expand All @@ -127,19 +136,19 @@ def minres(A, b, M=None, x0=None, tol=1e-5, maxiters=None, verbose=True):
w = (v - oldeps * w1 - delta * w2) * denom
x = x + phi * w

gmax = torch.max(gmax, gamma)
gmin = torch.min(gmin, gamma)
gmax = torch.max(gmax, gamma).max()
gmin = torch.min(gmin, gamma).min()
z = rhs1 / gamma
rhs1 = rhs2 - delta * z
rhs2 = - epsilon * z

Anorm = torch.sqrt(tnorm2)
Anorm = torch.sqrt(tnorm2.max())
ynorm = torch.linalg.norm(x)
epsa = Anorm * eps
epsx = Anorm * ynorm * eps
#epsr = Anorm * ynorm * tol
diag = gbar
if diag == 0:
if (diag == 0).any():
diag = epsa
qrnorm = phibar
rnorm = qrnorm
Expand All @@ -153,22 +162,25 @@ def minres(A, b, M=None, x0=None, tol=1e-5, maxiters=None, verbose=True):
else:
test2 = root / Anorm
Acond = gmax / gmin
t1 = 1 + test1
t2 = 1 + test2
t1 = 1 + test1.max()
t2 = 1 + test2.max()
if t2 <= 1:
break
if t1 <= 1:
break

if Acond >= 0.1 / eps:
assert False, "System is ill-conditioned."
if epsx >= beta1:
if (epsx >= beta1).any():
assert False
if test2 <= tol:
if test2.max() <= tol:
break
if test1 <= tol:
if test1.max() <= tol:
break
return x

def L2norm2entry(x, y):
return torch.sqrt(x**2 + y**2)

def inner(x, y):
return (x * y).sum(axis=0) # Supports block
23 changes: 23 additions & 0 deletions tests/test_bminres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import linops as lo
import linops.block_minres
import scipy.sparse as sp
import scipy.sparse.linalg

def to_implement_test():
A_tensor = torch.Tensor(
[[3., 0., 0],
[0., 80., 0],
[0., 0, -7.]]
)
A = lo.MatrixOperator(
A_tensor
)
b = torch.Tensor([[5.], [7.], [2.]])
y_lo = lo.block_minres.bminres(A, b)

A = A_tensor.numpy()
b = b.numpy()
y_sp = sp.linalg.minres(A, b, show=True, maxiter=10)
print(y_lo)
print(y_sp)
12 changes: 7 additions & 5 deletions tests/test_minres.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def to_implement_test():
A = lo.MatrixOperator(
A_tensor
)
b = torch.Tensor([5., 7.])
b = torch.Tensor([[5., 7.], [3., 6]])
y_lo = lo.minres.minres(A, b)

A = A_tensor.numpy()
b = b.numpy()
y_sp = sp.linalg.minres(A, b)
#A = A_tensor.numpy()
#b = b.numpy()
#y_sp = sp.linalg.minres(A, b)
print(y_lo)
print(y_sp)
print(torch.linalg.vector_norm(A @ y_lo - b))

to_implement_test()

0 comments on commit 06e500b

Please sign in to comment.