Skip to content

Commit

Permalink
Merge pull request #129 from osqp/as/torch
Browse files Browse the repository at this point in the history
As/torch
  • Loading branch information
AmitSolomonPrinceton authored Mar 26, 2024
2 parents b0c472c + 00af17b commit c2de60d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 35 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ list(APPEND CMAKE_MESSAGE_INDENT " ")
FetchContent_Declare(
osqp
GIT_REPOSITORY https://github.com/osqp/osqp.git
GIT_TAG vb/henryiii/skbuild
GIT_TAG 02a117cfc8ad21b06c2596603a2046ee61c82786
)
list(POP_BACK CMAKE_MESSAGE_INDENT)
FetchContent_MakeAvailable(osqp)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"qdldl",
"scipy>=0.13.2,<1.12.0",
"setuptools",
"joblib",
]

[project.optional-dependencies]
Expand Down
137 changes: 103 additions & 34 deletions src/osqp/nn/torch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import numpy as np
import scipy.sparse as spa
import torch
from torch.nn import Module
from torch.autograd import Function

try:
import torch
from torch.nn import Module
from torch.autograd import Function
except ImportError as e:
print(f'Import Error: {e}')
from joblib import Parallel, delayed
import multiprocessing

import osqp

Expand Down Expand Up @@ -112,6 +118,56 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):
"""

def _get_update_flag(n_batch: int) -> bool:
    """Decide whether cached solvers should be updated or freshly created.

    Returns True when one solver per batch element already exists (update
    them in place), False when none exist yet (generate them). Raises a
    RuntimeError for any other solver count, which indicates a stale or
    corrupted cache.
    """
    count = len(solvers)
    if count != 0 and count != n_batch:
        raise RuntimeError(f'Invalid number of solvers: expected 0 or {n_batch}, but got {count}.')
    return count == n_batch

def _inner_solve(i, update_flag, q, l, u, P_val, P_idx, A_val, A_idx, solver_type, eps_abs, eps_rel):
    """Solve the i-th QP of the batch and return ``(solver, x)``.

    ``update_flag`` is passed as an argument rather than read from the
    enclosing scope so its value cannot change while a parallel run is in
    flight. When it is True the cached solver for index ``i`` is updated in
    place with the new problem data; otherwise a fresh OSQP solver is set up.

    Raises a RuntimeError when the solver does not report 'solved'.
    """
    if update_flag:
        # Reuse the cached solver; push only the new vectors/matrix data.
        solver = solvers[i]
        solver.update(
            q=q[i],
            l=l[i],
            u=u[i],
            Px=to_numpy(P_val[i]),
            Px_idx=P_idx,
            Ax=to_numpy(A_val[i]),
            Ax_idx=A_idx,
        )
    else:
        # No cached solver yet: assemble sparse matrices and set one up.
        # TODO: Deep copy when available
        P = spa.csc_matrix((to_numpy(P_val[i]), P_idx), shape=P_shape)
        A = spa.csc_matrix((to_numpy(A_val[i]), A_idx), shape=A_shape)
        solver = osqp.OSQP(algebra=algebra)
        solver.setup(
            P,
            q[i],
            A,
            l[i],
            u[i],
            solver_type=solver_type,
            verbose=verbose,
            eps_abs=eps_abs,
            eps_rel=eps_rel,
        )

    result = solver.solve()
    status = result.info.status
    if status != 'solved':
        # TODO: We can replace this with something calmer and
        # add some more options around potentially ignoring this.
        raise RuntimeError(f'Unable to solve QP, status: {status}')
    return solver, result.x

params = [P_val, q_val, A_val, l_val, u_val]

for p in params:
Expand All @@ -138,43 +194,39 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):
assert A_val.size(1) == len(A_idx[0]), 'Unexpected size of A'
assert P_val.size(1) == len(P_idx[0]), 'Unexpected size of P'

P = [spa.csc_matrix((to_numpy(P_val[i]), P_idx), shape=P_shape) for i in range(n_batch)]
q = [to_numpy(q_val[i]) for i in range(n_batch)]
A = [spa.csc_matrix((to_numpy(A_val[i]), A_idx), shape=A_shape) for i in range(n_batch)]
l = [to_numpy(l_val[i]) for i in range(n_batch)]
u = [to_numpy(u_val[i]) for i in range(n_batch)]

# Perform forward step solving the QPs
x_torch = torch.zeros((n_batch, n), dtype=dtype, device=device)

x = []
for i in range(n_batch):
# Solve QP
# TODO: Cache solver object in between
solver = osqp.OSQP(algebra=algebra)
solver.setup(
P[i],
q[i],
A[i],
l[i],
u[i],
update_flag = _get_update_flag(n_batch)
n_jobs = multiprocessing.cpu_count()
res = Parallel(n_jobs=n_jobs, prefer='threads')(
delayed(_inner_solve)(
i=i,
update_flag=update_flag,
q=q,
l=l,
u=u,
P_val=P_val,
P_idx=P_idx,
A_val=A_val,
A_idx=A_idx,
solver_type=solver_type,
verbose=verbose,
eps_abs=eps_abs,
eps_rel=eps_rel,
)
result = solver.solve()
solvers.append(solver)
status = result.info.status
if status != 'solved':
# TODO: We can replace this with something calmer and
# add some more options around potentially ignoring this.
raise RuntimeError(f'Unable to solve QP, status: {status}')
x.append(result.x)

# This is silently converting result.x to the same
# dtype and device as x_torch.
x_torch[i] = torch.from_numpy(result.x)
for i in range(n_batch)
)
solvers_loop, x = zip(*res)
for i in range(n_batch):
if update_flag:
solvers[i] = solvers_loop[i]
else:
solvers.append(solvers_loop[i])
x_torch[i] = torch.from_numpy(x[i])

# Return solutions
if not batch_mode:
Expand All @@ -184,6 +236,18 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):

@staticmethod
def backward(ctx, dl_dx_val):
def _loop_adjoint_derivative(solver, dl_dx):
    """Backward pass for a single batch element.

    Runs the adjoint-derivative computation on ``solver`` using the incoming
    gradient ``dl_dx``, then returns the gradients ``(dq, dl, du, dP, dA)``
    as torch tensors.
    """
    solver.adjoint_derivative_compute(dx=dl_dx)
    dP_sp, dA_sp = solver.adjoint_derivative_get_mat(as_dense=False, dP_as_triu=False)
    dq_np, dl_np, du_np = solver.adjoint_derivative_get_vec()
    dq = torch.from_numpy(dq_np)
    dl = torch.from_numpy(dl_np)
    du = torch.from_numpy(du_np)
    # For the sparse matrix gradients, only the nonzero data array (.x)
    # is converted — presumably the sparsity pattern matches P_idx/A_idx
    # from the forward pass (TODO confirm).
    dP = torch.from_numpy(dP_sp.x)
    dA = torch.from_numpy(dA_sp.x)
    return dq, dl, du, dP, dA

dtype = dl_dx_val.dtype
device = dl_dx_val.device

Expand All @@ -208,12 +272,17 @@ def backward(ctx, dl_dx_val):
dl = torch.zeros((n_batch, m), dtype=dtype, device=device)
du = torch.zeros((n_batch, m), dtype=dtype, device=device)

n_jobs = multiprocessing.cpu_count()
res = Parallel(n_jobs=n_jobs, prefer='threads')(
delayed(_loop_adjoint_derivative)(solvers[i], dl_dx[i]) for i in range(n_batch)
)
dq_vec, dl_vec, du_vec, dP_vec, dA_vec = zip(*res)
for i in range(n_batch):
solvers[i].adjoint_derivative_compute(dx=dl_dx[i])
dPi_np, dAi_np = solvers[i].adjoint_derivative_get_mat(as_dense=False, dP_as_triu=False)
dqi_np, dli_np, dui_np = solvers[i].adjoint_derivative_get_vec()
dq[i], dl[i], du[i] = [torch.from_numpy(d) for d in [dqi_np, dli_np, dui_np]]
dP[i], dA[i] = [torch.from_numpy(d.x) for d in [dPi_np, dAi_np]]
dq[i] = dq_vec[i]
dl[i] = dl_vec[i]
du[i] = du_vec[i]
dP[i] = dP_vec[i]
dA[i] = dA_vec[i]

grads = [dP, dq, dA, dl, du]

Expand Down

0 comments on commit c2de60d

Please sign in to comment.