diff --git a/CMakeLists.txt b/CMakeLists.txt
index 93a847a5..77d033a4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,7 +17,7 @@ list(APPEND CMAKE_MESSAGE_INDENT "  ")
 FetchContent_Declare(
   osqp
   GIT_REPOSITORY https://github.com/osqp/osqp.git
-  GIT_TAG vb/henryiii/skbuild
+  GIT_TAG 02a117cfc8ad21b06c2596603a2046ee61c82786
 )
 list(POP_BACK CMAKE_MESSAGE_INDENT)
 FetchContent_MakeAvailable(osqp)
diff --git a/pyproject.toml b/pyproject.toml
index 608d6162..94817a9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
     "qdldl",
     "scipy>=0.13.2,<1.12.0",
     "setuptools",
+    "joblib",
 ]
 
 [project.optional-dependencies]
diff --git a/src/osqp/nn/torch.py b/src/osqp/nn/torch.py
index 3728345d..fda1dbc8 100644
--- a/src/osqp/nn/torch.py
+++ b/src/osqp/nn/torch.py
@@ -1,8 +1,14 @@
 import numpy as np
 import scipy.sparse as spa
-import torch
-from torch.nn import Module
-from torch.autograd import Function
+
+try:
+    import torch
+    from torch.nn import Module
+    from torch.autograd import Function
+except ImportError as e:
+    print(f'Import Error: {e}')
+from joblib import Parallel, delayed
+import multiprocessing
 
 import osqp
 
@@ -112,6 +118,56 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):
             """
 
+            def _get_update_flag(n_batch: int) -> bool:
+                """
+                Helper that returns True if the cached solvers should be
+                updated in place, and False if they still need to be
+                created. Raises a RuntimeError if the number of cached
+                solvers is invalid.
+                """
+                num_solvers = len(solvers)
+                if num_solvers not in (0, n_batch):
+                    raise RuntimeError(f'Invalid number of solvers: expected 0 or {n_batch}, but got {num_solvers}.')
+                return num_solvers == n_batch
+
+            def _inner_solve(i, update_flag, q, l, u, P_val, P_idx, A_val, A_idx, solver_type, eps_abs, eps_rel):
+                """
+                Solve the i-th QP of the batch. update_flag must be passed
+                in from outside so that it cannot change during a parallel
+                run.
+                """
+                if update_flag:
+                    # Refresh the cached solver with the new problem data.
+                    solver = solvers[i]
+                    solver.update(
+                        q=q[i], l=l[i], u=u[i], Px=to_numpy(P_val[i]), Px_idx=P_idx, Ax=to_numpy(A_val[i]), Ax_idx=A_idx
+                    )
+                else:
+                    P = spa.csc_matrix((to_numpy(P_val[i]), P_idx), shape=P_shape)
+                    A = spa.csc_matrix((to_numpy(A_val[i]), A_idx), shape=A_shape)
+                    # TODO: Deep copy when available
+                    solver = osqp.OSQP(algebra=algebra)
+                    solver.setup(
+                        P,
+                        q[i],
+                        A,
+                        l[i],
+                        u[i],
+                        solver_type=solver_type,
+                        verbose=verbose,
+                        eps_abs=eps_abs,
+                        eps_rel=eps_rel,
+                    )
+                result = solver.solve()
+                status = result.info.status
+                if status != 'solved':
+                    # TODO: We can replace this with something calmer and
+                    # add some more options around potentially ignoring this.
+                    raise RuntimeError(f'Unable to solve QP, status: {status}')
+
+                return solver, result.x
+
             params = [P_val, q_val, A_val, l_val, u_val]
 
             for p in params:
@@ -138,43 +194,39 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):
             assert A_val.size(1) == len(A_idx[0]), 'Unexpected size of A'
             assert P_val.size(1) == len(P_idx[0]), 'Unexpected size of P'
 
-            P = [spa.csc_matrix((to_numpy(P_val[i]), P_idx), shape=P_shape) for i in range(n_batch)]
             q = [to_numpy(q_val[i]) for i in range(n_batch)]
-            A = [spa.csc_matrix((to_numpy(A_val[i]), A_idx), shape=A_shape) for i in range(n_batch)]
             l = [to_numpy(l_val[i]) for i in range(n_batch)]
             u = [to_numpy(u_val[i]) for i in range(n_batch)]
 
             # Perform forward step solving the QPs
             x_torch = torch.zeros((n_batch, n), dtype=dtype, device=device)
 
-            x = []
-            for i in range(n_batch):
-                # Solve QP
-                # TODO: Cache solver object in between
-                solver = osqp.OSQP(algebra=algebra)
-                solver.setup(
-                    P[i],
-                    q[i],
-                    A[i],
-                    l[i],
-                    u[i],
+            update_flag = _get_update_flag(n_batch)
+            n_jobs = multiprocessing.cpu_count()
+            res = Parallel(n_jobs=n_jobs, prefer='threads')(
+                delayed(_inner_solve)(
+                    i=i,
+                    update_flag=update_flag,
+                    q=q,
+                    l=l,
+                    u=u,
+                    P_val=P_val,
+                    P_idx=P_idx,
+                    A_val=A_val,
+                    A_idx=A_idx,
                     solver_type=solver_type,
-                    verbose=verbose,
                     eps_abs=eps_abs,
                     eps_rel=eps_rel,
                 )
-                result = solver.solve()
-                solvers.append(solver)
-                status = result.info.status
-                if status != 'solved':
-                    # TODO: We can replace this with something calmer and
-                    # add some more options around potentially ignoring this.
-                    raise RuntimeError(f'Unable to solve QP, status: {status}')
-                x.append(result.x)
-
-                # This is silently converting result.x to the same
-                # dtype and device as x_torch.
-                x_torch[i] = torch.from_numpy(result.x)
+                for i in range(n_batch)
+            )
+            solvers_loop, x = zip(*res)
+            for i in range(n_batch):
+                if update_flag:
+                    solvers[i] = solvers_loop[i]
+                else:
+                    solvers.append(solvers_loop[i])
+                x_torch[i] = torch.from_numpy(x[i])
 
             # Return solutions
             if not batch_mode:
@@ -184,6 +236,18 @@ def forward(ctx, P_val, q_val, A_val, l_val, u_val):
 
         @staticmethod
        def backward(ctx, dl_dx_val):
+            def _loop_adjoint_derivative(solver, dl_dx):
+                """
+                Compute dq[i], dl[i], du[i], dP[i], dA[i] from solvers[i]
+                and dl_dx[i].
+ """ + solver.adjoint_derivative_compute(dx=dl_dx) + dPi_np, dAi_np = solver.adjoint_derivative_get_mat(as_dense=False, dP_as_triu=False) + dqi_np, dli_np, dui_np = solver.adjoint_derivative_get_vec() + dq, dl, du = [torch.from_numpy(d) for d in [dqi_np, dli_np, dui_np]] + dP, dA = [torch.from_numpy(d.x) for d in [dPi_np, dAi_np]] + return dq, dl, du, dP, dA + dtype = dl_dx_val.dtype device = dl_dx_val.device @@ -208,12 +272,17 @@ def backward(ctx, dl_dx_val): dl = torch.zeros((n_batch, m), dtype=dtype, device=device) du = torch.zeros((n_batch, m), dtype=dtype, device=device) + n_jobs = multiprocessing.cpu_count() + res = Parallel(n_jobs=n_jobs, prefer='threads')( + delayed(_loop_adjoint_derivative)(solvers[i], dl_dx[i]) for i in range(n_batch) + ) + dq_vec, dl_vec, du_vec, dP_vec, dA_vec = zip(*res) for i in range(n_batch): - solvers[i].adjoint_derivative_compute(dx=dl_dx[i]) - dPi_np, dAi_np = solvers[i].adjoint_derivative_get_mat(as_dense=False, dP_as_triu=False) - dqi_np, dli_np, dui_np = solvers[i].adjoint_derivative_get_vec() - dq[i], dl[i], du[i] = [torch.from_numpy(d) for d in [dqi_np, dli_np, dui_np]] - dP[i], dA[i] = [torch.from_numpy(d.x) for d in [dPi_np, dAi_np]] + dq[i] = dq_vec[i] + dl[i] = dl_vec[i] + du[i] = du_vec[i] + dP[i] = dP_vec[i] + dA[i] = dA_vec[i] grads = [dP, dq, dA, dl, du]